Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
ca46b9d2
Commit
ca46b9d2
authored
Jul 16, 2025
by
Andrea Paris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
cleanup of unnecessary docstrings
parent
9c26a6d8
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
10 additions
and
333 deletions
+10
-333
torch_harmonics/distributed/distributed_resample.py
torch_harmonics/distributed/distributed_resample.py
+0
-3
torch_harmonics/distributed/distributed_sht.py
torch_harmonics/distributed/distributed_sht.py
+1
-37
torch_harmonics/distributed/primitives.py
torch_harmonics/distributed/primitives.py
+0
-23
torch_harmonics/examples/losses.py
torch_harmonics/examples/losses.py
+1
-75
torch_harmonics/examples/metrics.py
torch_harmonics/examples/metrics.py
+1
-1
torch_harmonics/examples/models/_layers.py
torch_harmonics/examples/models/_layers.py
+2
-37
torch_harmonics/examples/stanford_2d3ds_dataset.py
torch_harmonics/examples/stanford_2d3ds_dataset.py
+1
-50
torch_harmonics/random_fields.py
torch_harmonics/random_fields.py
+0
-29
torch_harmonics/resample.py
torch_harmonics/resample.py
+2
-28
torch_harmonics/sht.py
torch_harmonics/sht.py
+2
-50
No files found.
torch_harmonics/distributed/distributed_resample.py
View file @
ca46b9d2
...
...
@@ -153,9 +153,6 @@ class DistributedResampleS2(nn.Module):
self
.
skip_resampling
=
(
nlon_in
==
nlon_out
)
and
(
nlat_in
==
nlat_out
)
and
(
grid_in
==
grid_out
)
def
extra_repr
(
self
):
r
"""
Pretty print module
"""
return
f
"in_shape=
{
(
self
.
nlat_in
,
self
.
nlon_in
)
}
, out_shape=
{
(
self
.
nlat_out
,
self
.
nlon_out
)
}
"
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
...
...
torch_harmonics/distributed/distributed_sht.py
View file @
ca46b9d2
...
...
@@ -77,27 +77,7 @@ class DistributedRealSHT(nn.Module):
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
norm
=
"ortho"
,
csphase
=
True
):
"""
Distribtued SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super
().
__init__
()
self
.
nlat
=
nlat
...
...
@@ -369,22 +349,6 @@ class DistributedRealVectorSHT(nn.Module):
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
norm
=
"ortho"
,
csphase
=
True
):
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters
----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
"""
super
().
__init__
()
...
...
torch_harmonics/distributed/primitives.py
View file @
ca46b9d2
...
...
@@ -102,25 +102,6 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
def
_transpose
(
tensor
,
dim0
,
dim1
,
dim1_split_sizes
,
group
=
None
,
async_op
=
False
):
"""
Transpose a tensor along two dimensions.
Parameters
----------
tensor: torch.Tensor
The tensor to transpose
dim0: int
The first dimension to transpose
dim1: int
The second dimension to transpose
dim1_split_sizes: List[int]
The split sizes for the second dimension
Returns
-------
tensor_list: List[torch.Tensor]
The split tensors
"""
# get comm params
comm_size
=
dist
.
get_world_size
(
group
=
group
)
...
...
@@ -198,7 +179,6 @@ class distributed_transpose_polar(torch.autograd.Function):
# we need those additional primitives for distributed matrix multiplications
def
_reduce
(
input_
,
use_fp32
=
True
,
group
=
None
):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if
dist
.
get_world_size
(
group
=
group
)
==
1
:
...
...
@@ -219,7 +199,6 @@ def _reduce(input_, use_fp32=True, group=None):
def
_split
(
input_
,
dim_
,
group
=
None
):
"""Split the tensor along its last dimension and keep the corresponding slice."""
# Bypass the function if we are using only 1 GPU.
comm_size
=
dist
.
get_world_size
(
group
=
group
)
if
comm_size
==
1
:
...
...
@@ -236,7 +215,6 @@ def _split(input_, dim_, group=None):
def
_gather
(
input_
,
dim_
,
shapes_
,
group
=
None
):
"""Gather unevenly split tensors across ranks"""
comm_size
=
dist
.
get_world_size
(
group
=
group
)
...
...
@@ -269,7 +247,6 @@ def _gather(input_, dim_, shapes_, group=None):
def
_reduce_scatter
(
input_
,
dim_
,
use_fp32
=
True
,
group
=
None
):
"""All-reduce the input tensor across model parallel group and scatter it back."""
# Bypass the function if we are using only 1 GPU.
if
dist
.
get_world_size
(
group
=
group
)
==
1
:
...
...
torch_harmonics/examples/losses.py
View file @
ca46b9d2
...
...
@@ -275,19 +275,10 @@ class SphericalLossBase(nn.Module, ABC):
@
abstractmethod
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Abstract method that must be implemented by child classes to compute loss terms.
Args:
prd (torch.Tensor): Prediction tensor
tar (torch.Tensor): Target tensor
Returns:
torch.Tensor: Computed loss term before integration
"""
pass
def
_post_integration_hook
(
self
,
loss
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Post-integration hook. Commonly used for the roots in Lp norms"""
return
loss
def
forward
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -309,21 +300,7 @@ class SquaredL2LossS2(SphericalLossBase):
"""
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute squared L2 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Squared difference between prediction and target
"""
return
torch
.
square
(
prd
-
tar
)
...
...
@@ -335,21 +312,7 @@ class L1LossS2(SphericalLossBase):
"""
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute L1 loss term.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Absolute difference between prediction and target
"""
return
torch
.
abs
(
prd
-
tar
)
...
...
@@ -361,19 +324,7 @@ class L2LossS2(SquaredL2LossS2):
"""
def
_post_integration_hook
(
self
,
loss
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Apply square root to get L2 norm.
Parameters
-----------
loss : torch.Tensor
Integrated squared loss
Returns
-------
torch.Tensor
Square root of the loss (L2 norm)
"""
return
torch
.
sqrt
(
loss
)
...
...
@@ -385,18 +336,7 @@ class W11LossS2(SphericalLossBase):
"""
def
__init__
(
self
,
nlat
:
int
,
nlon
:
int
,
grid
:
str
=
"equiangular"
):
"""
Initialize W11 loss.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
"""
super
().
__init__
(
nlat
=
nlat
,
nlon
=
nlon
,
grid
=
grid
)
# Set up grid and domain for FFT
l_phi
=
2
*
torch
.
pi
# domain size
...
...
@@ -512,21 +452,7 @@ class NormalLossS2(SphericalLossBase):
return
normals
def
_compute_loss_term
(
self
,
prd
:
torch
.
Tensor
,
tar
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Compute combined L1 and normal consistency loss.
Parameters
-----------
prd : torch.Tensor
Prediction tensor
tar : torch.Tensor
Target tensor
Returns
-------
torch.Tensor
Combined loss term
"""
# Handle dimensions for both prediction and target
# Ensure we have at least a batch dimension
if
prd
.
dim
()
==
2
:
...
...
torch_harmonics/examples/metrics.py
View file @
ca46b9d2
...
...
@@ -113,7 +113,7 @@ def _get_stats_multiclass(
def
_predict_classes
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Convert logits to class predictions using softmax and argmax.
Parameters
...
...
torch_harmonics/examples/models/_layers.py
View file @
ca46b9d2
...
...
@@ -41,46 +41,11 @@ from torch_harmonics import InverseRealSHT
def
_no_grad_trunc_normal_
(
tensor
,
mean
,
std
,
a
,
b
):
"""
Initialize tensor with truncated normal distribution without gradients.
This is a helper function for trunc_normal_ that performs the actual initialization
without requiring gradients to be tracked.
Parameters
-----------
tensor : torch.Tensor
Tensor to initialize
mean : float
Mean of the normal distribution
std : float
Standard deviation of the normal distribution
a : float
Lower bound for truncation
b : float
Upper bound for truncation
Returns
-------
torch.Tensor
Initialized tensor
"""
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def
norm_cdf
(
x
):
"""
Compute standard normal cumulative distribution function.
Parameters
-----------
x : float
Input value
Returns
-------
float
CDF value
"""
# Computes standard normal cumulative distribution function
return
(
1.0
+
math
.
erf
(
x
/
math
.
sqrt
(
2.0
)))
/
2.0
...
...
torch_harmonics/examples/stanford_2d3ds_dataset.py
View file @
ca46b9d2
...
...
@@ -80,34 +80,13 @@ class Stanford2D3DSDownloader:
"""
def
__init__
(
self
,
base_url
:
str
=
DEFAULT_BASE_URL
,
local_dir
:
str
=
"data"
):
"""
Initialize the Stanford 2D3DS dataset downloader.
Parameters
-----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
"""
self
.
base_url
=
base_url
self
.
local_dir
=
local_dir
os
.
makedirs
(
self
.
local_dir
,
exist_ok
=
True
)
def
_download_file
(
self
,
filename
):
"""
Download a single file with progress bar and resume capability.
Parameters
-----------
filename : str
Name of the file to download
Returns
-------
str
Local path to the downloaded file
"""
import
requests
from
tqdm
import
tqdm
...
...
@@ -141,19 +120,7 @@ class Stanford2D3DSDownloader:
return
local_path
def
_extract_tar
(
self
,
tar_path
):
"""
Extract a tar file and return the extracted directory name.
Parameters
-----------
tar_path : str
Path to the tar file to extract
Returns
-------
str
Name of the extracted directory
"""
import
tarfile
with
tarfile
.
open
(
tar_path
)
as
tar
:
...
...
@@ -194,23 +161,7 @@ class Stanford2D3DSDownloader:
return
data_folders
,
class_labels
def
_rgb_to_id
(
self
,
img
,
class_labels_map
,
class_labels_indices
):
"""
Convert RGB image to class ID using color mapping.
Parameters
-----------
img : numpy.ndarray
RGB image array
class_labels_map : list
Mapping from color values to class labels
class_labels_indices : list
List of class label indices
Returns
-------
numpy.ndarray
Class ID array
"""
# Convert to int32 first to avoid overflow
r
=
img
[...,
0
].
astype
(
np
.
int32
)
g
=
img
[...,
1
].
astype
(
np
.
int32
)
...
...
torch_harmonics/random_fields.py
View file @
ca46b9d2
...
...
@@ -35,35 +35,6 @@ from .sht import InverseRealSHT
class
GaussianRandomFieldS2
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
nlat
,
alpha
=
2.0
,
tau
=
3.0
,
sigma
=
None
,
radius
=
1.0
,
grid
=
"equiangular"
,
dtype
=
torch
.
float32
):
super
().
__init__
()
r
"""
A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
and sigma, tau, alpha are scalar parameters.
Note: C is trace-class on L^2 if and only if alpha > 1.
Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
alpha : float, default is 2
Regularity parameter. Larger means smoother.
tau : float, default is 3
Lenght-scale parameter. Larger means more scales.
sigma : float, default is None
Scale parameter. Larger means bigger.
If None, sigma = tau**(0.5*(2*alpha - 2.0)).
radius : float, default is 1
Radius of the sphere.
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
dtype : torch.dtype, default is torch.float32
Numerical type for the calculations.
"""
#Number of latitudinal modes.
self
.
nlat
=
nlat
...
...
torch_harmonics/resample.py
View file @
ca46b9d2
...
...
@@ -137,25 +137,11 @@ class ResampleS2(nn.Module):
def
extra_repr
(
self
):
r
"""
Pretty print module
"""
return
f
"in_shape=
{
(
self
.
nlat_in
,
self
.
nlon_in
)
}
, out_shape=
{
(
self
.
nlat_out
,
self
.
nlon_out
)
}
"
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
"""
Interpolate the input tensor along the longitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along longitude dimension
"""
# do the interpolation in precision of x
lwgt
=
self
.
lon_weights
.
to
(
x
.
dtype
)
if
self
.
mode
==
"bilinear"
:
...
...
@@ -192,19 +178,7 @@ class ResampleS2(nn.Module):
return
x
def
_upscale_latitudes
(
self
,
x
:
torch
.
Tensor
):
"""
Interpolate the input tensor along the latitude dimension.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Interpolated tensor along latitude dimension
"""
# do the interpolation in precision of x
lwgt
=
self
.
lat_weights
.
to
(
x
.
dtype
)
if
self
.
mode
==
"bilinear"
:
...
...
torch_harmonics/sht.py
View file @
ca46b9d2
...
...
@@ -72,31 +72,7 @@ class RealSHT(nn.Module):
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
norm
=
"ortho"
,
csphase
=
True
):
r
"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
"""
super
().
__init__
()
...
...
@@ -314,31 +290,7 @@ class RealVectorSHT(nn.Module):
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
norm
=
"ortho"
,
csphase
=
True
):
r
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
"""
super
().
__init__
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment