Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
a8f2af6c
Commit
a8f2af6c
authored
Jul 17, 2025
by
Andrea Paris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
updated docstring
parent
42067ef2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
170 deletions
+13
-170
torch_harmonics/examples/shallow_water_equations.py
torch_harmonics/examples/shallow_water_equations.py
+13
-170
No files found.
torch_harmonics/examples/shallow_water_equations.py
View file @
a8f2af6c
...
@@ -142,139 +142,43 @@ class ShallowWaterSolver(nn.Module):
...
@@ -142,139 +142,43 @@ class ShallowWaterSolver(nn.Module):
self
.
register_buffer
(
'quad_weights'
,
quad_weights
)
self
.
register_buffer
(
'quad_weights'
,
quad_weights
)
def
grid2spec
(
self
,
ugrid
):
def
grid2spec
(
self
,
ugrid
):
"""
"""Convert spatial data to spectral coefficients."""
Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return
self
.
sht
(
ugrid
)
return
self
.
sht
(
ugrid
)
def
spec2grid
(
self
,
uspec
):
def
spec2grid
(
self
,
uspec
):
"""
"""Convert spectral coefficients to spatial data."""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return
self
.
isht
(
uspec
)
return
self
.
isht
(
uspec
)
def
vrtdivspec
(
self
,
ugrid
):
def
vrtdivspec
(
self
,
ugrid
):
"""
"""Compute vorticity and divergence from velocity field."""
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
vrtdivspec
=
self
.
lap
*
self
.
radius
*
self
.
vsht
(
ugrid
)
vrtdivspec
=
self
.
lap
*
self
.
radius
*
self
.
vsht
(
ugrid
)
return
vrtdivspec
return
vrtdivspec
def
getuv
(
self
,
vrtdivspec
):
def
getuv
(
self
,
vrtdivspec
):
"""
"""Compute wind vector from spectral coefficients of vorticity and divergence."""
Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
"""
return
self
.
ivsht
(
self
.
invlap
*
vrtdivspec
/
self
.
radius
)
return
self
.
ivsht
(
self
.
invlap
*
vrtdivspec
/
self
.
radius
)
def
gethuv
(
self
,
uspec
):
def
gethuv
(
self
,
uspec
):
"""
"""Compute height and wind vector from spectral coefficients."""
Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
"""
hgrid
=
self
.
spec2grid
(
uspec
[:
1
])
hgrid
=
self
.
spec2grid
(
uspec
[:
1
])
uvgrid
=
self
.
getuv
(
uspec
[
1
:])
uvgrid
=
self
.
getuv
(
uspec
[
1
:])
return
torch
.
cat
((
hgrid
,
uvgrid
),
dim
=-
3
)
return
torch
.
cat
((
hgrid
,
uvgrid
),
dim
=-
3
)
def
potential_vorticity
(
self
,
uspec
):
def
potential_vorticity
(
self
,
uspec
):
"""
"""Compute potential vorticity from spectral coefficients."""
Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
"""
ugrid
=
self
.
spec2grid
(
uspec
)
ugrid
=
self
.
spec2grid
(
uspec
)
pvrt
=
(
0.5
*
self
.
havg
*
self
.
gravity
/
self
.
omega
)
*
(
ugrid
[
1
]
+
self
.
coriolis
)
/
ugrid
[
0
]
pvrt
=
(
0.5
*
self
.
havg
*
self
.
gravity
/
self
.
omega
)
*
(
ugrid
[
1
]
+
self
.
coriolis
)
/
ugrid
[
0
]
return
pvrt
return
pvrt
def
dimensionless
(
self
,
uspec
):
def
dimensionless
(
self
,
uspec
):
"""
"""Remove dimensions from variables for dimensionless analysis."""
Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
"""
uspec
[
0
]
=
(
uspec
[
0
]
-
self
.
havg
*
self
.
gravity
)
/
self
.
hamp
/
self
.
gravity
uspec
[
0
]
=
(
uspec
[
0
]
-
self
.
havg
*
self
.
gravity
)
/
self
.
hamp
/
self
.
gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
uspec
[
1
:]
=
uspec
[
1
:]
*
self
.
radius
/
torch
.
sqrt
(
self
.
gravity
*
self
.
havg
)
uspec
[
1
:]
=
uspec
[
1
:]
*
self
.
radius
/
torch
.
sqrt
(
self
.
gravity
*
self
.
havg
)
return
uspec
return
uspec
def
dudtspec
(
self
,
uspec
):
def
dudtspec
(
self
,
uspec
):
"""
"""Compute time derivatives from solution represented in spectral coefficients."""
Compute time derivatives from solution represented in spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
dudtspec
=
torch
.
zeros_like
(
uspec
)
dudtspec
=
torch
.
zeros_like
(
uspec
)
# compute the derivatives - this should be incorporated into the solver:
# compute the derivatives - this should be incorporated into the solver:
...
@@ -299,23 +203,7 @@ class ShallowWaterSolver(nn.Module):
...
@@ -299,23 +203,7 @@ class ShallowWaterSolver(nn.Module):
return
dudtspec
return
dudtspec
def
galewsky_initial_condition
(
self
):
def
galewsky_initial_condition
(
self
):
"""
"""Initialize non-linear barotropically unstable shallow water test case."""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Parameters
----------
None
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
device
=
self
.
lap
.
device
device
=
self
.
lap
.
device
umax
=
80.
umax
=
80.
...
@@ -353,19 +241,7 @@ class ShallowWaterSolver(nn.Module):
...
@@ -353,19 +241,7 @@ class ShallowWaterSolver(nn.Module):
return
torch
.
tril
(
uspec
)
return
torch
.
tril
(
uspec
)
def
random_initial_condition
(
self
,
mach
=
0.1
)
->
torch
.
Tensor
:
def
random_initial_condition
(
self
,
mach
=
0.1
)
->
torch
.
Tensor
:
"""
"""Generate random initial condition on the sphere."""
Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
"""
device
=
self
.
lap
.
device
device
=
self
.
lap
.
device
ctype
=
torch
.
complex128
if
self
.
lap
.
dtype
==
torch
.
float64
else
torch
.
complex64
ctype
=
torch
.
complex128
if
self
.
lap
.
dtype
==
torch
.
float64
else
torch
.
complex64
...
@@ -411,22 +287,7 @@ class ShallowWaterSolver(nn.Module):
...
@@ -411,22 +287,7 @@ class ShallowWaterSolver(nn.Module):
return
torch
.
tril
(
uspec
)
return
torch
.
tril
(
uspec
)
def
timestep
(
self
,
uspec
:
torch
.
Tensor
,
nsteps
:
int
)
->
torch
.
Tensor
:
def
timestep
(
self
,
uspec
:
torch
.
Tensor
,
nsteps
:
int
)
->
torch
.
Tensor
:
"""
"""Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps."""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
dudtspec
=
torch
.
zeros
(
3
,
3
,
self
.
lmax
,
self
.
mmax
,
dtype
=
uspec
.
dtype
,
device
=
uspec
.
device
)
dudtspec
=
torch
.
zeros
(
3
,
3
,
self
.
lmax
,
self
.
mmax
,
dtype
=
uspec
.
dtype
,
device
=
uspec
.
device
)
# pointers to indicate the most current result
# pointers to indicate the most current result
...
@@ -458,23 +319,7 @@ class ShallowWaterSolver(nn.Module):
...
@@ -458,23 +319,7 @@ class ShallowWaterSolver(nn.Module):
return
uspec
return
uspec
def
integrate_grid
(
self
,
ugrid
,
dimensionless
=
False
,
polar_opt
=
0
):
def
integrate_grid
(
self
,
ugrid
,
dimensionless
=
False
,
polar_opt
=
0
):
"""
"""Integrate the solution on the grid."""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
dlon
=
2
*
torch
.
pi
/
self
.
nlon
dlon
=
2
*
torch
.
pi
/
self
.
nlon
radius
=
1
if
dimensionless
else
self
.
radius
radius
=
1
if
dimensionless
else
self
.
radius
if
polar_opt
>
0
:
if
polar_opt
>
0
:
...
@@ -485,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
...
@@ -485,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
def
plot_griddata
(
self
,
data
,
fig
,
cmap
=
'twilight_shifted'
,
vmax
=
None
,
vmin
=
None
,
projection
=
'3d'
,
title
=
None
,
antialiased
=
False
):
def
plot_griddata
(
self
,
data
,
fig
,
cmap
=
'twilight_shifted'
,
vmax
=
None
,
vmin
=
None
,
projection
=
'3d'
,
title
=
None
,
antialiased
=
False
):
"""
"""Plotting routine for data on the grid. Requires cartopy for 3d plots."""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
lons
=
self
.
lons
.
squeeze
()
-
torch
.
pi
lons
=
self
.
lons
.
squeeze
()
-
torch
.
pi
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment