Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
e4879676
Commit
e4879676
authored
Jun 26, 2025
by
apaaris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
Added docstrings to many methods
parent
b5c410c0
Changes
29
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
815 additions
and
43 deletions
+815
-43
torch_harmonics/examples/models/s2unet.py
torch_harmonics/examples/models/s2unet.py
+122
-1
torch_harmonics/examples/models/sfno.py
torch_harmonics/examples/models/sfno.py
+58
-11
torch_harmonics/examples/pde_sphere.py
torch_harmonics/examples/pde_sphere.py
+124
-11
torch_harmonics/examples/shallow_water_equations.py
torch_harmonics/examples/shallow_water_equations.py
+136
-12
torch_harmonics/examples/stanford_2d3ds_dataset.py
torch_harmonics/examples/stanford_2d3ds_dataset.py
+96
-4
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+91
-4
torch_harmonics/plotting.py
torch_harmonics/plotting.py
+86
-0
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+26
-0
torch_harmonics/resample.py
torch_harmonics/resample.py
+76
-0
No files found.
torch_harmonics/examples/models/s2unet.py
View file @
e4879676
...
@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
...
@@ -54,6 +54,46 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
class
DownsamplingBlock
(
nn
.
Module
):
class
DownsamplingBlock
(
nn
.
Module
):
"""
Downsampling block for spherical U-Net architecture.
This block performs convolution operations followed by downsampling on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
downsampling_mode : str, optional
Downsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_shape
,
in_shape
,
...
@@ -154,12 +194,33 @@ class DownsamplingBlock(nn.Module):
...
@@ -154,12 +194,33 @@ class DownsamplingBlock(nn.Module):
self
.
apply
(
self
.
_init_weights
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
def
_init_weights
(
self
,
m
):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
m
.
bias
is
not
None
:
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Forward pass of the downsampling block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Downsampled tensor
"""
# skip connection
# skip connection
residual
=
x
residual
=
x
if
hasattr
(
self
,
"transform_skip"
):
if
hasattr
(
self
,
"transform_skip"
):
...
@@ -178,6 +239,46 @@ class DownsamplingBlock(nn.Module):
...
@@ -178,6 +239,46 @@ class DownsamplingBlock(nn.Module):
class
UpsamplingBlock
(
nn
.
Module
):
class
UpsamplingBlock
(
nn
.
Module
):
"""
Upsampling block for spherical U-Net architecture.
This block performs upsampling followed by convolution operations on spherical data,
using discrete-continuous convolutions to maintain spectral properties.
Parameters
-----------
in_shape : tuple
Input shape (nlat, nlon)
out_shape : tuple
Output shape (nlat, nlon)
in_channels : int
Number of input channels
out_channels : int
Number of output channels
grid_in : str, optional
Input grid type, by default "equiangular"
grid_out : str, optional
Output grid type, by default "equiangular"
nrep : int, optional
Number of convolution repetitions, by default 1
kernel_shape : tuple, optional
Kernel shape for convolution, by default (3, 3)
basis_type : str, optional
Filter basis type, by default "morlet"
activation : nn.Module, optional
Activation function, by default nn.ReLU
transform_skip : bool, optional
Whether to transform skip connection, by default False
drop_conv_rate : float, optional
Dropout rate for convolutions, by default 0.0
drop_path_rate : float, optional
Drop path rate, by default 0.0
drop_dense_rate : float, optional
Dropout rate for dense layers, by default 0.0
upsampling_mode : str, optional
Upsampling mode ("bilinear", "conv"), by default "bilinear"
"""
def
__init__
(
def
__init__
(
self
,
self
,
in_shape
,
in_shape
,
...
@@ -496,6 +597,14 @@ class SphericalUNet(nn.Module):
...
@@ -496,6 +597,14 @@ class SphericalUNet(nn.Module):
self
.
apply
(
self
.
_init_weights
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
def
_init_weights
(
self
,
m
):
"""
Initialize weights for the module.
Parameters
-----------
m : nn.Module
Module to initialize
"""
if
isinstance
(
m
,
nn
.
Conv2d
):
if
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
m
.
bias
is
not
None
:
if
m
.
bias
is
not
None
:
...
@@ -505,7 +614,19 @@ class SphericalUNet(nn.Module):
...
@@ -505,7 +614,19 @@ class SphericalUNet(nn.Module):
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
Forward pass through the complete spherical U-Net model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
# encoder:
# encoder:
features
=
[]
features
=
[]
feat
=
x
feat
=
x
...
...
torch_harmonics/examples/models/sfno.py
View file @
e4879676
...
@@ -118,9 +118,20 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
...
@@ -118,9 +118,20 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
else
:
else
:
raise
ValueError
(
f
"Unknown skip connection type
{
outer_skip
}
"
)
raise
ValueError
(
f
"Unknown skip connection type
{
outer_skip
}
"
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
Forward pass through the SFNO block.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Output tensor after processing through the block
"""
x
,
residual
=
self
.
global_conv
(
x
)
x
,
residual
=
self
.
global_conv
(
x
)
x
=
self
.
norm
(
x
)
x
=
self
.
norm
(
x
)
...
@@ -147,8 +158,12 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -147,8 +158,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Parameters
Parameters
----------
----------
img_s
hap
e : tuple, optional
img_s
iz
e : tuple, optional
Shape of the input channels, by default (128, 256)
Shape of the input channels, by default (128, 256)
grid : str, optional
Input grid type, by default "equiangular"
grid_internal : str, optional
Internal grid type for computations, by default "legendre-gauss"
scale_factor : int, optional
scale_factor : int, optional
Scale factor to use, by default 3
Scale factor to use, by default 3
in_chans : int, optional
in_chans : int, optional
...
@@ -172,20 +187,20 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -172,20 +187,20 @@ class SphericalFourierNeuralOperator(nn.Module):
drop_path_rate : float, optional
drop_path_rate : float, optional
Dropout path rate, by default 0.0
Dropout path rate, by default 0.0
normalization_layer : str, optional
normalization_layer : str, optional
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "
instance_norm
"
Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "
none
"
hard_thresholding_fraction : float, optional
hard_thresholding_fraction : float, optional
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
Fraction of hard thresholding (frequency cutoff) to apply, by default 1.0
residual_prediction : bool, optional
residual_prediction : bool, optional
Whether to add a single large skip connection, by default
Tru
e
Whether to add a single large skip connection, by default
Fals
e
pos_embed :
bool
, optional
pos_embed :
str
, optional
Whether to use
positional embedding, by default
True
Type of
positional embedding
to use
, by default
"none"
bias : bool, optional
bias : bool, optional
Whether to use a bias, by default False
Whether to use a bias, by default False
Example:
Example:
--------
--------
>>> model = SphericalFourierNeuralOperator(
>>> model = SphericalFourierNeuralOperator(
... img_s
hap
e=(128, 256),
... img_s
iz
e=(128, 256),
... scale_factor=4,
... scale_factor=4,
... in_chans=2,
... in_chans=2,
... out_chans=2,
... out_chans=2,
...
@@ -355,10 +370,30 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -355,10 +370,30 @@ class SphericalFourierNeuralOperator(nn.Module):
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
def
no_weight_decay
(
self
):
return
{
"pos_embed"
,
"cls_token"
}
"""
Return a set of parameter names that should not be decayed.
Returns
-------
set
Set of parameter names to exclude from weight decay
"""
return
{
"pos_embed.pos_embed"
}
def
forward_features
(
self
,
x
):
def
forward_features
(
self
,
x
):
"""
Forward pass through the feature extraction layers.
Parameters
-----------
x : torch.Tensor
Input tensor
Returns
-------
torch.Tensor
Features after processing through the network
"""
x
=
self
.
pos_drop
(
x
)
x
=
self
.
pos_drop
(
x
)
for
blk
in
self
.
blocks
:
for
blk
in
self
.
blocks
:
...
@@ -367,7 +402,19 @@ class SphericalFourierNeuralOperator(nn.Module):
...
@@ -367,7 +402,19 @@ class SphericalFourierNeuralOperator(nn.Module):
return
x
return
x
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
"""
Forward pass through the complete SFNO model.
Parameters
-----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
if
self
.
residual_prediction
:
if
self
.
residual_prediction
:
residual
=
x
residual
=
x
...
...
torch_harmonics/examples/pde_sphere.py
View file @
e4879676
...
@@ -42,7 +42,27 @@ import numpy as np
...
@@ -42,7 +42,27 @@ import numpy as np
class
SphereSolver
(
nn
.
Module
):
class
SphereSolver
(
nn
.
Module
):
"""
"""
Solver class on the sphere. Can solve the following PDEs:
Solver class on the sphere. Can solve the following PDEs:
- Allen-Cahn eq
- Allen-Cahn equation
- Ginzburg-Landau equation
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere, by default 1.0
coeff : float, optional
Coefficient for the PDE, by default 0.001
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
dt
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
radius
=
1.0
,
coeff
=
0.001
):
def
__init__
(
self
,
nlat
,
nlon
,
dt
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
radius
=
1.0
,
coeff
=
0.001
):
...
@@ -97,17 +117,58 @@ class SphereSolver(nn.Module):
...
@@ -97,17 +117,58 @@ class SphereSolver(nn.Module):
self
.
register_buffer
(
'invlap'
,
invlap
)
self
.
register_buffer
(
'invlap'
,
invlap
)
def
grid2spec
(
self
,
u
):
def
grid2spec
(
self
,
u
):
"""spectral coefficients from spatial data"""
"""
Convert spatial data to spectral coefficients.
Parameters
-----------
u : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return
self
.
sht
(
u
)
return
self
.
sht
(
u
)
def
spec2grid
(
self
,
uspec
):
def
spec2grid
(
self
,
uspec
):
"""spatial data from spectral coefficients"""
"""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return
self
.
isht
(
uspec
)
return
self
.
isht
(
uspec
)
def
dudtspec
(
self
,
uspec
,
pde
=
'allen-cahn'
):
def
dudtspec
(
self
,
uspec
,
pde
=
'allen-cahn'
):
"""
Compute the time derivative of spectral coefficients for different PDEs.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients
pde : str, optional
PDE type ("allen-cahn", "ginzburg-landau"), by default "allen-cahn"
Returns
-------
torch.Tensor
Time derivative of spectral coefficients
Raises
------
NotImplementedError
If PDE type is not supported
"""
if
pde
==
'allen-cahn'
:
if
pde
==
'allen-cahn'
:
ugrid
=
self
.
spec2grid
(
uspec
)
ugrid
=
self
.
spec2grid
(
uspec
)
u3spec
=
self
.
grid2spec
(
ugrid
**
3
)
u3spec
=
self
.
grid2spec
(
ugrid
**
3
)
...
@@ -117,20 +178,55 @@ class SphereSolver(nn.Module):
...
@@ -117,20 +178,55 @@ class SphereSolver(nn.Module):
u3spec
=
self
.
grid2spec
(
ugrid
**
3
)
u3spec
=
self
.
grid2spec
(
ugrid
**
3
)
dudtspec
=
uspec
+
(
1.
+
2.j
)
*
self
.
coeff
*
self
.
lap
*
uspec
-
(
1.
+
2.j
)
*
u3spec
dudtspec
=
uspec
+
(
1.
+
2.j
)
*
self
.
coeff
*
self
.
lap
*
uspec
-
(
1.
+
2.j
)
*
u3spec
else
:
else
:
NotImplementedError
raise
NotImplementedError
(
f
"PDE type
{
pde
}
not implemented"
)
return
dudtspec
return
dudtspec
def
randspec
(
self
):
def
randspec
(
self
):
"""random data on the sphere"""
"""
Generate random spectral data on the sphere.
Returns
-------
torch.Tensor
Random spectral coefficients
"""
rspec
=
torch
.
randn_like
(
self
.
lap
)
/
4
/
torch
.
pi
rspec
=
torch
.
randn_like
(
self
.
lap
)
/
4
/
torch
.
pi
return
rspec
return
rspec
def
plot_griddata
(
self
,
data
,
fig
,
cmap
=
'twilight_shifted'
,
vmax
=
None
,
vmin
=
None
,
projection
=
'3d'
,
title
=
None
,
antialiased
=
False
):
def
plot_griddata
(
self
,
data
,
fig
,
cmap
=
'twilight_shifted'
,
vmax
=
None
,
vmin
=
None
,
projection
=
'3d'
,
title
=
None
,
antialiased
=
False
):
"""
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
Plot data on the sphere grid. Requires cartopy for 3d plots.
Parameters
-----------
data : torch.Tensor
Data to plot
fig : matplotlib.figure.Figure
Figure to plot on
cmap : str, optional
Colormap name, by default 'twilight_shifted'
vmax : float, optional
Maximum value for color scaling, by default None
vmin : float, optional
Minimum value for color scaling, by default None
projection : str, optional
Projection type ("mollweide", "3d"), by default "3d"
title : str, optional
Plot title, by default None
antialiased : bool, optional
Whether to use antialiasing, by default False
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
Raises
------
NotImplementedError
If projection type is not supported
"""
"""
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
...
@@ -172,9 +268,26 @@ class SphereSolver(nn.Module):
...
@@ -172,9 +268,26 @@ class SphereSolver(nn.Module):
plt
.
title
(
title
,
y
=
1.05
)
plt
.
title
(
title
,
y
=
1.05
)
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
(
f
"Projection
{
projection
}
not implemented"
)
return
im
return
im
def
plot_specdata
(
self
,
data
,
fig
,
**
kwargs
):
def
plot_specdata
(
self
,
data
,
fig
,
**
kwargs
):
"""
Plot spectral data by converting to spatial data first.
Parameters
-----------
data : torch.Tensor
Spectral data to plot
fig : matplotlib.figure.Figure
Figure to plot on
**kwargs
Additional arguments passed to plot_griddata
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
return
self
.
plot_griddata
(
self
.
isht
(
data
),
fig
,
**
kwargs
)
return
self
.
plot_griddata
(
self
.
isht
(
data
),
fig
,
**
kwargs
)
torch_harmonics/examples/shallow_water_equations.py
View file @
e4879676
...
@@ -41,7 +41,35 @@ import numpy as np
...
@@ -41,7 +41,35 @@ import numpy as np
class
ShallowWaterSolver
(
nn
.
Module
):
class
ShallowWaterSolver
(
nn
.
Module
):
"""
"""
SWE solver class. Interface inspired bu pyspharm and SHTns
Shallow Water Equations (SWE) solver class for spherical geometry.
Interface inspired by pyspharm and SHTns. Solves the shallow water equations
on a rotating sphere using spectral methods.
Parameters
-----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
dt : float
Time step size
lmax : int, optional
Maximum l mode for spherical harmonics, by default None
mmax : int, optional
Maximum m mode for spherical harmonics, by default None
grid : str, optional
Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
radius : float, optional
Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
omega : float, optional
Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
gravity : float, optional
Gravitational acceleration in m/s², by default 9.80616
havg : float, optional
Average height in meters, by default 10.e3
hamp : float, optional
Height amplitude in meters, by default 120.
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
dt
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
radius
=
6.37122E6
,
\
def
__init__
(
self
,
nlat
,
nlon
,
dt
,
lmax
=
None
,
mmax
=
None
,
grid
=
"equiangular"
,
radius
=
6.37122E6
,
\
...
@@ -115,30 +143,82 @@ class ShallowWaterSolver(nn.Module):
...
@@ -115,30 +143,82 @@ class ShallowWaterSolver(nn.Module):
def
grid2spec
(
self
,
ugrid
):
def
grid2spec
(
self
,
ugrid
):
"""
"""
spectral coefficients from spatial data
Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
"""
return
self
.
sht
(
ugrid
)
return
self
.
sht
(
ugrid
)
def
spec2grid
(
self
,
uspec
):
def
spec2grid
(
self
,
uspec
):
"""
"""
spatial data from spectral coefficients
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
"""
return
self
.
isht
(
uspec
)
return
self
.
isht
(
uspec
)
def
vrtdivspec
(
self
,
ugrid
):
def
vrtdivspec
(
self
,
ugrid
):
"""spatial data from spectral coefficients"""
"""
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
vrtdivspec
=
self
.
lap
*
self
.
radius
*
self
.
vsht
(
ugrid
)
vrtdivspec
=
self
.
lap
*
self
.
radius
*
self
.
vsht
(
ugrid
)
return
vrtdivspec
return
vrtdivspec
def
getuv
(
self
,
vrtdivspec
):
def
getuv
(
self
,
vrtdivspec
):
"""
"""
compute wind vector from spectral coeffs of vorticity and divergence
Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
"""
"""
return
self
.
ivsht
(
self
.
invlap
*
vrtdivspec
/
self
.
radius
)
return
self
.
ivsht
(
self
.
invlap
*
vrtdivspec
/
self
.
radius
)
def
gethuv
(
self
,
uspec
):
def
gethuv
(
self
,
uspec
):
"""
"""
compute wind vector from spectral coeffs of vorticity and divergence
Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
"""
"""
hgrid
=
self
.
spec2grid
(
uspec
[:
1
])
hgrid
=
self
.
spec2grid
(
uspec
[:
1
])
uvgrid
=
self
.
getuv
(
uspec
[
1
:])
uvgrid
=
self
.
getuv
(
uspec
[
1
:])
...
@@ -146,7 +226,17 @@ class ShallowWaterSolver(nn.Module):
...
@@ -146,7 +226,17 @@ class ShallowWaterSolver(nn.Module):
def
potential_vorticity
(
self
,
uspec
):
def
potential_vorticity
(
self
,
uspec
):
"""
"""
Compute potential vorticity
Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
"""
"""
ugrid
=
self
.
spec2grid
(
uspec
)
ugrid
=
self
.
spec2grid
(
uspec
)
pvrt
=
(
0.5
*
self
.
havg
*
self
.
gravity
/
self
.
omega
)
*
(
ugrid
[
1
]
+
self
.
coriolis
)
/
ugrid
[
0
]
pvrt
=
(
0.5
*
self
.
havg
*
self
.
gravity
/
self
.
omega
)
*
(
ugrid
[
1
]
+
self
.
coriolis
)
/
ugrid
[
0
]
...
@@ -154,7 +244,17 @@ class ShallowWaterSolver(nn.Module):
...
@@ -154,7 +244,17 @@ class ShallowWaterSolver(nn.Module):
def
dimensionless
(
self
,
uspec
):
def
dimensionless
(
self
,
uspec
):
"""
"""
Remove dimensions from variables
Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
"""
"""
uspec
[
0
]
=
(
uspec
[
0
]
-
self
.
havg
*
self
.
gravity
)
/
self
.
hamp
/
self
.
gravity
uspec
[
0
]
=
(
uspec
[
0
]
-
self
.
havg
*
self
.
gravity
)
/
self
.
hamp
/
self
.
gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
...
@@ -163,9 +263,18 @@ class ShallowWaterSolver(nn.Module):
...
@@ -163,9 +263,18 @@ class ShallowWaterSolver(nn.Module):
def
dudtspec
(
self
,
uspec
):
def
dudtspec
(
self
,
uspec
):
"""
"""
Compute time derivatives from solution represented in spectral coefficients
Compute time derivatives from solution represented in spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
"""
dudtspec
=
torch
.
zeros_like
(
uspec
)
dudtspec
=
torch
.
zeros_like
(
uspec
)
# compute the derivatives - this should be incorporated into the solver:
# compute the derivatives - this should be incorporated into the solver:
...
@@ -191,10 +300,15 @@ class ShallowWaterSolver(nn.Module):
...
@@ -191,10 +300,15 @@ class ShallowWaterSolver(nn.Module):
def
galewsky_initial_condition
(
self
):
def
galewsky_initial_condition
(
self
):
"""
"""
Initialize
s
non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
"""
"""
device
=
self
.
lap
.
device
device
=
self
.
lap
.
device
...
@@ -234,7 +348,17 @@ class ShallowWaterSolver(nn.Module):
...
@@ -234,7 +348,17 @@ class ShallowWaterSolver(nn.Module):
def
random_initial_condition
(
self
,
mach
=
0.1
)
->
torch
.
Tensor
:
def
random_initial_condition
(
self
,
mach
=
0.1
)
->
torch
.
Tensor
:
"""
"""
random initial condition on the sphere
Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
"""
"""
device
=
self
.
lap
.
device
device
=
self
.
lap
.
device
ctype
=
torch
.
complex128
if
self
.
lap
.
dtype
==
torch
.
float64
else
torch
.
complex64
ctype
=
torch
.
complex128
if
self
.
lap
.
dtype
==
torch
.
float64
else
torch
.
complex64
...
...
torch_harmonics/examples/stanford_2d3ds_dataset.py
View file @
e4879676
...
@@ -66,13 +66,34 @@ class Stanford2D3DSDownloader:
...
@@ -66,13 +66,34 @@ class Stanford2D3DSDownloader:
"""
"""
def
__init__
(
self
,
base_url
:
str
=
DEFAULT_BASE_URL
,
local_dir
:
str
=
"data"
):
def
__init__
(
self
,
base_url
:
str
=
DEFAULT_BASE_URL
,
local_dir
:
str
=
"data"
):
"""
Initialize the Stanford 2D3DS dataset downloader.
Parameters
-----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
"""
self
.
base_url
=
base_url
self
.
base_url
=
base_url
self
.
local_dir
=
local_dir
self
.
local_dir
=
local_dir
os
.
makedirs
(
self
.
local_dir
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
local_dir
,
exist_ok
=
True
)
def
_download_file
(
self
,
filename
):
def
_download_file
(
self
,
filename
):
"""
Download a single file with progress bar and resume capability.
Parameters
-----------
filename : str
Name of the file to download
Returns
-------
str
Local path to the downloaded file
"""
import
requests
import
requests
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -106,6 +127,19 @@ class Stanford2D3DSDownloader:
...
@@ -106,6 +127,19 @@ class Stanford2D3DSDownloader:
return
local_path
return
local_path
def
_extract_tar
(
self
,
tar_path
):
def
_extract_tar
(
self
,
tar_path
):
"""
Extract a tar file and return the extracted directory name.
Parameters
-----------
tar_path : str
Path to the tar file to extract
Returns
-------
str
Name of the extracted directory
"""
import
tarfile
import
tarfile
with
tarfile
.
open
(
tar_path
)
as
tar
:
with
tarfile
.
open
(
tar_path
)
as
tar
:
...
@@ -116,7 +150,20 @@ class Stanford2D3DSDownloader:
...
@@ -116,7 +150,20 @@ class Stanford2D3DSDownloader:
return
extracted_dir
return
extracted_dir
def
download_dataset
(
self
,
file_extracted_directory_pairs
=
DEFAULT_TAR_FILE_PAIRS
):
def
download_dataset
(
self
,
file_extracted_directory_pairs
=
DEFAULT_TAR_FILE_PAIRS
):
"""
Download and extract the complete dataset.
Parameters
-----------
file_extracted_directory_pairs : list, optional
List of (filename, extracted_folder_name) pairs, by default DEFAULT_TAR_FILE_PAIRS
Returns
-------
tuple
(data_folders, class_labels) where data_folders is a list of extracted directory names
and class_labels is the semantic label mapping
"""
import
requests
import
requests
data_folders
=
[]
data_folders
=
[]
...
@@ -133,6 +180,23 @@ class Stanford2D3DSDownloader:
...
@@ -133,6 +180,23 @@ class Stanford2D3DSDownloader:
return
data_folders
,
class_labels
return
data_folders
,
class_labels
def
_rgb_to_id
(
self
,
img
,
class_labels_map
,
class_labels_indices
):
def
_rgb_to_id
(
self
,
img
,
class_labels_map
,
class_labels_indices
):
"""
Convert RGB image to class ID using color mapping.
Parameters
-----------
img : numpy.ndarray
RGB image array
class_labels_map : list
Mapping from color values to class labels
class_labels_indices : list
List of class label indices
Returns
-------
numpy.ndarray
Class ID array
"""
# Convert to int32 first to avoid overflow
# Convert to int32 first to avoid overflow
r
=
img
[...,
0
].
astype
(
np
.
int32
)
r
=
img
[...,
0
].
astype
(
np
.
int32
)
g
=
img
[...,
1
].
astype
(
np
.
int32
)
g
=
img
[...,
1
].
astype
(
np
.
int32
)
...
@@ -167,7 +231,35 @@ class Stanford2D3DSDownloader:
...
@@ -167,7 +231,35 @@ class Stanford2D3DSDownloader:
downsampling_factor
:
int
=
16
,
downsampling_factor
:
int
=
16
,
remove_alpha_channel
:
bool
=
True
,
remove_alpha_channel
:
bool
=
True
,
):
):
"""
Convert the downloaded dataset to HDF5 format for efficient loading.
Parameters
-----------
data_folders : list
List of extracted data folder names
class_labels : list
List of semantic class labels
rgb_path : str, optional
Relative path to RGB images within each data folder, by default "pano/rgb"
semantic_path : str, optional
Relative path to semantic labels within each data folder, by default "pano/semantic"
depth_path : str, optional
Relative path to depth images within each data folder, by default "pano/depth"
output_filename : str, optional
Suffix for semantic label files, by default "semantic"
dataset_file : str, optional
Output HDF5 filename, by default "stanford_2d3ds_dataset.h5"
downsampling_factor : int, optional
Factor by which to downsample images, by default 16
remove_alpha_channel : bool, optional
Whether to remove alpha channel from RGB images, by default True
Returns
-------
str
Path to the created HDF5 dataset file
"""
converted_dataset_path
=
os
.
path
.
join
(
self
.
local_dir
,
dataset_file
)
converted_dataset_path
=
os
.
path
.
join
(
self
.
local_dir
,
dataset_file
)
from
PIL
import
Image
from
PIL
import
Image
...
...
torch_harmonics/filter_basis.py
View file @
e4879676
...
@@ -62,12 +62,23 @@ class FilterBasis(metaclass=abc.ABCMeta):
...
@@ -62,12 +62,23 @@ class FilterBasis(metaclass=abc.ABCMeta):
self
,
self
,
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
):
):
"""
Initialize the filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
self
.
kernel_shape
=
kernel_shape
self
.
kernel_shape
=
kernel_shape
@
property
@
property
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
kernel_size
(
self
):
def
kernel_size
(
self
):
"""
Abstract property that should return the size of the kernel.
Returns:
int: the kernel size
"""
raise
NotImplementedError
raise
NotImplementedError
# @abc.abstractmethod
# @abc.abstractmethod
...
@@ -109,7 +120,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
...
@@ -109,7 +120,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
self
,
self
,
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
):
):
"""
Initialize the piecewise linear filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if
isinstance
(
kernel_shape
,
int
):
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
]
kernel_shape
=
[
kernel_shape
]
if
len
(
kernel_shape
)
==
1
:
if
len
(
kernel_shape
)
==
1
:
...
@@ -121,6 +137,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
...
@@ -121,6 +137,12 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@
property
@
property
def
kernel_size
(
self
):
def
kernel_size
(
self
):
"""
Compute the kernel size for piecewise linear basis.
Returns:
int: the kernel size
"""
return
(
self
.
kernel_shape
[
0
]
//
2
)
*
self
.
kernel_shape
[
1
]
+
self
.
kernel_shape
[
0
]
%
2
return
(
self
.
kernel_shape
[
0
]
//
2
)
*
self
.
kernel_shape
[
1
]
+
self
.
kernel_shape
[
0
]
%
2
def
_compute_support_vals_isotropic
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
):
def
_compute_support_vals_isotropic
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
):
...
@@ -225,7 +247,12 @@ class MorletFilterBasis(FilterBasis):
...
@@ -225,7 +247,12 @@ class MorletFilterBasis(FilterBasis):
self
,
self
,
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
kernel_shape
:
Union
[
int
,
Tuple
[
int
],
Tuple
[
int
,
int
]],
):
):
"""
Initialize the Morlet filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if
isinstance
(
kernel_shape
,
int
):
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
,
kernel_shape
]
kernel_shape
=
[
kernel_shape
,
kernel_shape
]
if
len
(
kernel_shape
)
!=
2
:
if
len
(
kernel_shape
)
!=
2
:
...
@@ -235,12 +262,38 @@ class MorletFilterBasis(FilterBasis):
...
@@ -235,12 +262,38 @@ class MorletFilterBasis(FilterBasis):
@
property
@
property
def
kernel_size
(
self
):
def
kernel_size
(
self
):
"""
Compute the kernel size for Morlet basis.
Returns:
int: the kernel size
"""
return
self
.
kernel_shape
[
0
]
*
self
.
kernel_shape
[
1
]
return
self
.
kernel_shape
[
0
]
*
self
.
kernel_shape
[
1
]
def
gaussian_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
def
gaussian_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
"""
Compute Gaussian window function.
Parameters:
r: radial distance tensor
width: width parameter of the Gaussian
Returns:
torch.Tensor: Gaussian window values
"""
return
1
/
(
2
*
math
.
pi
*
width
**
2
)
*
torch
.
exp
(
-
0.5
*
r
**
2
/
(
width
**
2
))
return
1
/
(
2
*
math
.
pi
*
width
**
2
)
*
torch
.
exp
(
-
0.5
*
r
**
2
/
(
width
**
2
))
def
hann_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
def
hann_window
(
self
,
r
:
torch
.
Tensor
,
width
:
float
=
1.0
):
"""
Compute Hann window function.
Parameters:
r: radial distance tensor
width: width parameter of the Hann window
Returns:
torch.Tensor: Hann window values
"""
return
torch
.
cos
(
0.5
*
torch
.
pi
*
r
/
width
)
**
2
return
torch
.
cos
(
0.5
*
torch
.
pi
*
r
/
width
)
**
2
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
,
width
:
float
=
1.0
):
def
compute_support_vals
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
,
width
:
float
=
1.0
):
...
@@ -282,7 +335,12 @@ class ZernikeFilterBasis(FilterBasis):
...
@@ -282,7 +335,12 @@ class ZernikeFilterBasis(FilterBasis):
self
,
self
,
kernel_shape
:
Union
[
int
,
Tuple
[
int
]],
kernel_shape
:
Union
[
int
,
Tuple
[
int
]],
):
):
"""
Initialize the Zernike filter basis.
Parameters:
kernel_shape: shape of the kernel, can be an integer or tuple of integers
"""
if
isinstance
(
kernel_shape
,
tuple
)
or
isinstance
(
kernel_shape
,
list
):
if
isinstance
(
kernel_shape
,
tuple
)
or
isinstance
(
kernel_shape
,
list
):
kernel_shape
=
kernel_shape
[
0
]
kernel_shape
=
kernel_shape
[
0
]
if
not
isinstance
(
kernel_shape
,
int
):
if
not
isinstance
(
kernel_shape
,
int
):
...
@@ -292,9 +350,26 @@ class ZernikeFilterBasis(FilterBasis):
...
@@ -292,9 +350,26 @@ class ZernikeFilterBasis(FilterBasis):
@
property
@
property
def
kernel_size
(
self
):
def
kernel_size
(
self
):
"""
Compute the kernel size for Zernike basis.
Returns:
int: the kernel size
"""
return
(
self
.
kernel_shape
*
(
self
.
kernel_shape
+
1
))
//
2
return
(
self
.
kernel_shape
*
(
self
.
kernel_shape
+
1
))
//
2
def
zernikeradial
(
self
,
r
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
m
:
torch
.
Tensor
):
def
zernikeradial
(
self
,
r
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
m
:
torch
.
Tensor
):
"""
Compute radial Zernike polynomials.
Parameters:
r: radial distance tensor
n: principal quantum number
m: azimuthal quantum number
Returns:
torch.Tensor: radial Zernike polynomial values
"""
out
=
torch
.
zeros_like
(
r
)
out
=
torch
.
zeros_like
(
r
)
bound
=
(
n
-
m
)
//
2
+
1
bound
=
(
n
-
m
)
//
2
+
1
max_bound
=
bound
.
max
().
item
()
max_bound
=
bound
.
max
().
item
()
...
@@ -307,6 +382,18 @@ class ZernikeFilterBasis(FilterBasis):
...
@@ -307,6 +382,18 @@ class ZernikeFilterBasis(FilterBasis):
return
out
return
out
def
zernikepoly
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
l
:
torch
.
Tensor
):
def
zernikepoly
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
:
torch
.
Tensor
,
l
:
torch
.
Tensor
):
"""
Compute Zernike polynomials.
Parameters:
r: radial distance tensor
phi: azimuthal angle tensor
n: principal quantum number
l: azimuthal quantum number
Returns:
torch.Tensor: Zernike polynomial values
"""
m
=
2
*
l
-
n
m
=
2
*
l
-
n
return
torch
.
where
(
m
<
0
,
self
.
zernikeradial
(
r
,
n
,
-
m
)
*
torch
.
sin
(
m
*
phi
),
self
.
zernikeradial
(
r
,
n
,
m
)
*
torch
.
cos
(
m
*
phi
))
return
torch
.
where
(
m
<
0
,
self
.
zernikeradial
(
r
,
n
,
-
m
)
*
torch
.
sin
(
m
*
phi
),
self
.
zernikeradial
(
r
,
n
,
m
)
*
torch
.
cos
(
m
*
phi
))
...
...
torch_harmonics/plotting.py
View file @
e4879676
...
@@ -47,6 +47,14 @@ except ImportError as err:
...
@@ -47,6 +47,14 @@ except ImportError as err:
def
check_plotting_dependencies
():
def
check_plotting_dependencies
():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if
plt
is
None
:
if
plt
is
None
:
raise
ImportError
(
"matplotlib is required for plotting functions. Install it with 'pip install matplotlib'"
)
raise
ImportError
(
"matplotlib is required for plotting functions. Install it with 'pip install matplotlib'"
)
if
cartopy
is
None
:
if
cartopy
is
None
:
...
@@ -58,6 +66,28 @@ def get_projection(
...
@@ -58,6 +66,28 @@ def get_projection(
central_latitude
=
0
,
central_latitude
=
0
,
central_longitude
=
0
,
central_longitude
=
0
,
):
):
"""
Get a cartopy projection object for map plotting.
Parameters
-----------
projection : str
Projection type ("orthographic", "robinson", "platecarree", "mollweide")
central_latitude : float, optional
Central latitude for the projection, by default 0
central_longitude : float, optional
Central longitude for the projection, by default 0
Returns
-------
cartopy.crs.Projection
Cartopy projection object
Raises
------
ValueError
If projection type is not supported
"""
if
projection
==
"orthographic"
:
if
projection
==
"orthographic"
:
proj
=
ccrs
.
Orthographic
(
central_latitude
=
central_latitude
,
central_longitude
=
central_longitude
)
proj
=
ccrs
.
Orthographic
(
central_latitude
=
central_latitude
,
central_longitude
=
central_longitude
)
elif
projection
==
"robinson"
:
elif
projection
==
"robinson"
:
...
@@ -77,6 +107,40 @@ def plot_sphere(
...
@@ -77,6 +107,40 @@ def plot_sphere(
):
):
"""
"""
Plots a function defined on the sphere using pcolormesh
Plots a function defined on the sphere using pcolormesh
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to plot with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
cmap : str, optional
Colormap name, by default "RdBu"
title : str, optional
Plot title, by default None
colorbar : bool, optional
Whether to add a colorbar, by default False
coastlines : bool, optional
Whether to add coastlines, by default False
gridlines : bool, optional
Whether to add gridlines, by default False
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
lon : numpy.ndarray, optional
Longitude coordinates, by default None (auto-generated)
lat : numpy.ndarray, optional
Latitude coordinates, by default None (auto-generated)
**kwargs
Additional arguments passed to pcolormesh
Returns
-------
matplotlib.collections.QuadMesh
The plotted image object
"""
"""
# make sure cartopy exist
# make sure cartopy exist
...
@@ -126,6 +190,28 @@ def plot_sphere(
...
@@ -126,6 +190,28 @@ def plot_sphere(
def
imshow_sphere
(
data
,
fig
=
None
,
projection
=
"robinson"
,
title
=
None
,
central_latitude
=
0
,
central_longitude
=
0
,
**
kwargs
):
def
imshow_sphere
(
data
,
fig
=
None
,
projection
=
"robinson"
,
title
=
None
,
central_latitude
=
0
,
central_longitude
=
0
,
**
kwargs
):
"""
"""
Displays an image on the sphere
Displays an image on the sphere
Parameters
-----------
data : numpy.ndarray or torch.Tensor
Data to display with shape (nlat, nlon)
fig : matplotlib.figure.Figure, optional
Figure to plot on, by default None (creates new figure)
projection : str, optional
Map projection type, by default "robinson"
title : str, optional
Plot title, by default None
central_latitude : float, optional
Central latitude for projection, by default 0
central_longitude : float, optional
Central longitude for projection, by default 0
**kwargs
Additional arguments passed to imshow
Returns
-------
matplotlib.image.AxesImage
The displayed image object
"""
"""
# make sure cartopy exist
# make sure cartopy exist
...
...
torch_harmonics/quadrature.py
View file @
e4879676
...
@@ -37,6 +37,32 @@ import torch
...
@@ -37,6 +37,32 @@ import torch
def
_precompute_grid
(
n
:
int
,
grid
:
Optional
[
str
]
=
"equidistant"
,
a
:
Optional
[
float
]
=
0.0
,
b
:
Optional
[
float
]
=
1.0
,
def
_precompute_grid
(
n
:
int
,
grid
:
Optional
[
str
]
=
"equidistant"
,
a
:
Optional
[
float
]
=
0.0
,
b
:
Optional
[
float
]
=
1.0
,
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
Precompute grid points and weights for various quadrature rules.
Parameters
-----------
n : int
Number of grid points
grid : str, optional
Grid type ("equidistant", "legendre-gauss", "lobatto", "equiangular"), by default "equidistant"
a : float, optional
Lower bound of interval, by default 0.0
b : float, optional
Upper bound of interval, by default 1.0
periodic : bool, optional
Whether the grid is periodic (only for equidistant), by default False
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Grid points and weights
Raises
------
ValueError
If periodic is True for non-equidistant grids or unknown grid type
"""
if
(
grid
!=
"equidistant"
)
and
periodic
:
if
(
grid
!=
"equidistant"
)
and
periodic
:
raise
ValueError
(
f
"Periodic grid is only supported on equidistant grids."
)
raise
ValueError
(
f
"Periodic grid is only supported on equidistant grids."
)
...
...
torch_harmonics/resample.py
View file @
e4879676
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment