Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
a8f2af6c
Commit
a8f2af6c
authored
Jul 17, 2025
by
Andrea Paris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
updated docstring
parent
42067ef2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
170 deletions
+13
-170
torch_harmonics/examples/shallow_water_equations.py
torch_harmonics/examples/shallow_water_equations.py
+13
-170
No files found.
torch_harmonics/examples/shallow_water_equations.py
View file @
a8f2af6c
...
...
@@ -142,139 +142,43 @@ class ShallowWaterSolver(nn.Module):
self
.
register_buffer
(
'quad_weights'
,
quad_weights
)
def
grid2spec
(
self
,
ugrid
):
"""
Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
"""Convert spatial data to spectral coefficients."""
return
self
.
sht
(
ugrid
)
def
spec2grid
(
self
,
uspec
):
"""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
"""Convert spectral coefficients to spatial data."""
return
self
.
isht
(
uspec
)
def
vrtdivspec
(
self
,
ugrid
):
"""
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
"""Compute vorticity and divergence from velocity field."""
vrtdivspec
=
self
.
lap
*
self
.
radius
*
self
.
vsht
(
ugrid
)
return
vrtdivspec
def
getuv
(
self
,
vrtdivspec
):
"""
Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
"""
"""Compute wind vector from spectral coefficients of vorticity and divergence."""
return
self
.
ivsht
(
self
.
invlap
*
vrtdivspec
/
self
.
radius
)
def
gethuv
(
self
,
uspec
):
"""
Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
"""
"""Compute height and wind vector from spectral coefficients."""
hgrid
=
self
.
spec2grid
(
uspec
[:
1
])
uvgrid
=
self
.
getuv
(
uspec
[
1
:])
return
torch
.
cat
((
hgrid
,
uvgrid
),
dim
=-
3
)
def
potential_vorticity
(
self
,
uspec
):
"""
Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
"""
"""Compute potential vorticity from spectral coefficients."""
ugrid
=
self
.
spec2grid
(
uspec
)
pvrt
=
(
0.5
*
self
.
havg
*
self
.
gravity
/
self
.
omega
)
*
(
ugrid
[
1
]
+
self
.
coriolis
)
/
ugrid
[
0
]
return
pvrt
def
dimensionless
(
self
,
uspec
):
"""
Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
"""
"""Remove dimensions from variables for dimensionless analysis."""
uspec
[
0
]
=
(
uspec
[
0
]
-
self
.
havg
*
self
.
gravity
)
/
self
.
hamp
/
self
.
gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r
uspec
[
1
:]
=
uspec
[
1
:]
*
self
.
radius
/
torch
.
sqrt
(
self
.
gravity
*
self
.
havg
)
return
uspec
def
dudtspec
(
self
,
uspec
):
"""
Compute time derivatives from solution represented in spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
"""Compute time derivatives from solution represented in spectral coefficients."""
dudtspec
=
torch
.
zeros_like
(
uspec
)
# compute the derivatives - this should be incorporated into the solver:
...
...
@@ -299,23 +203,7 @@ class ShallowWaterSolver(nn.Module):
return
dudtspec
def
galewsky_initial_condition
(
self
):
"""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Parameters
----------
None
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
"""Initialize non-linear barotropically unstable shallow water test case."""
device
=
self
.
lap
.
device
umax
=
80.
...
...
@@ -353,19 +241,7 @@ class ShallowWaterSolver(nn.Module):
return
torch
.
tril
(
uspec
)
def
random_initial_condition
(
self
,
mach
=
0.1
)
->
torch
.
Tensor
:
"""
Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
"""
"""Generate random initial condition on the sphere."""
device
=
self
.
lap
.
device
ctype
=
torch
.
complex128
if
self
.
lap
.
dtype
==
torch
.
float64
else
torch
.
complex64
...
...
@@ -411,22 +287,7 @@ class ShallowWaterSolver(nn.Module):
return
torch
.
tril
(
uspec
)
def
timestep
(
self
,
uspec
:
torch
.
Tensor
,
nsteps
:
int
)
->
torch
.
Tensor
:
"""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
"""Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps."""
dudtspec
=
torch
.
zeros
(
3
,
3
,
self
.
lmax
,
self
.
mmax
,
dtype
=
uspec
.
dtype
,
device
=
uspec
.
device
)
# pointers to indicate the most current result
...
...
@@ -458,23 +319,7 @@ class ShallowWaterSolver(nn.Module):
return
uspec
def
integrate_grid
(
self
,
ugrid
,
dimensionless
=
False
,
polar_opt
=
0
):
"""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
"""Integrate the solution on the grid."""
dlon
=
2
*
torch
.
pi
/
self
.
nlon
radius
=
1
if
dimensionless
else
self
.
radius
if
polar_opt
>
0
:
...
...
@@ -485,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
def
plot_griddata
(
self
,
data
,
fig
,
cmap
=
'twilight_shifted'
,
vmax
=
None
,
vmin
=
None
,
projection
=
'3d'
,
title
=
None
,
antialiased
=
False
):
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
"""Plotting routine for data on the grid. Requires cartopy for 3d plots."""
import
matplotlib.pyplot
as
plt
lons
=
self
.
lons
.
squeeze
()
-
torch
.
pi
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment