Commit a8f2af6c authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

updated docstring

parent 42067ef2
...@@ -142,139 +142,43 @@ class ShallowWaterSolver(nn.Module): ...@@ -142,139 +142,43 @@ class ShallowWaterSolver(nn.Module):
self.register_buffer('quad_weights', quad_weights) self.register_buffer('quad_weights', quad_weights)
def grid2spec(self, ugrid): def grid2spec(self, ugrid):
""" """Convert spatial data to spectral coefficients."""
Convert spatial data to spectral coefficients.
Parameters
-----------
ugrid : torch.Tensor
Spatial data tensor
Returns
-------
torch.Tensor
Spectral coefficients
"""
return self.sht(ugrid) return self.sht(ugrid)
def spec2grid(self, uspec): def spec2grid(self, uspec):
""" """Convert spectral coefficients to spatial data."""
Convert spectral coefficients to spatial data.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients tensor
Returns
-------
torch.Tensor
Spatial data
"""
return self.isht(uspec) return self.isht(uspec)
def vrtdivspec(self, ugrid): def vrtdivspec(self, ugrid):
""" """Compute vorticity and divergence from velocity field."""
Compute vorticity and divergence from velocity field.
Parameters
-----------
ugrid : torch.Tensor
Velocity field in spatial coordinates
Returns
-------
torch.Tensor
Spectral coefficients of vorticity and divergence
"""
vrtdivspec = self.lap * self.radius * self.vsht(ugrid) vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
return vrtdivspec return vrtdivspec
def getuv(self, vrtdivspec): def getuv(self, vrtdivspec):
""" """Compute wind vector from spectral coefficients of vorticity and divergence."""
Compute wind vector from spectral coefficients of vorticity and divergence.
Parameters
-----------
vrtdivspec : torch.Tensor
Spectral coefficients of vorticity and divergence
Returns
-------
torch.Tensor
Wind vector field in spatial coordinates
"""
return self.ivsht( self.invlap * vrtdivspec / self.radius) return self.ivsht( self.invlap * vrtdivspec / self.radius)
def gethuv(self, uspec): def gethuv(self, uspec):
""" """Compute height and wind vector from spectral coefficients."""
Compute height and wind vector from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Combined height and wind vector field
"""
hgrid = self.spec2grid(uspec[:1]) hgrid = self.spec2grid(uspec[:1])
uvgrid = self.getuv(uspec[1:]) uvgrid = self.getuv(uspec[1:])
return torch.cat((hgrid, uvgrid), dim=-3) return torch.cat((hgrid, uvgrid), dim=-3)
def potential_vorticity(self, uspec): def potential_vorticity(self, uspec):
""" """Compute potential vorticity from spectral coefficients."""
Compute potential vorticity from spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Potential vorticity field
"""
ugrid = self.spec2grid(uspec) ugrid = self.spec2grid(uspec)
pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0] pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
return pvrt return pvrt
def dimensionless(self, uspec): def dimensionless(self, uspec):
""" """Remove dimensions from variables for dimensionless analysis."""
Remove dimensions from variables for dimensionless analysis.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients with dimensions
Returns
-------
torch.Tensor
Dimensionless spectral coefficients
"""
uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
# vorticity is measured in 1/s so we normalize using sqrt(g h) / r # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg) uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
return uspec return uspec
def dudtspec(self, uspec): def dudtspec(self, uspec):
""" """Compute time derivatives from solution represented in spectral coefficients."""
Compute time derivatives from solution represented in spectral coefficients.
Parameters
-----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
Returns
-------
torch.Tensor
Time derivatives of spectral coefficients
"""
dudtspec = torch.zeros_like(uspec) dudtspec = torch.zeros_like(uspec)
# compute the derivatives - this should be incorporated into the solver: # compute the derivatives - this should be incorporated into the solver:
...@@ -299,23 +203,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -299,23 +203,7 @@ class ShallowWaterSolver(nn.Module):
return dudtspec return dudtspec
def galewsky_initial_condition(self): def galewsky_initial_condition(self):
""" """Initialize non-linear barotropically unstable shallow water test case."""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
Parameters
----------
None
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
device = self.lap.device device = self.lap.device
umax = 80. umax = 80.
...@@ -353,19 +241,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -353,19 +241,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec) return torch.tril(uspec)
def random_initial_condition(self, mach=0.1) -> torch.Tensor: def random_initial_condition(self, mach=0.1) -> torch.Tensor:
""" """Generate random initial condition on the sphere."""
Generate random initial condition on the sphere.
Parameters
-----------
mach : float, optional
Mach number for scaling the random perturbations, by default 0.1
Returns
-------
torch.Tensor
Random initial spectral coefficients
"""
device = self.lap.device device = self.lap.device
ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64 ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
...@@ -411,22 +287,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -411,22 +287,7 @@ class ShallowWaterSolver(nn.Module):
return torch.tril(uspec) return torch.tril(uspec)
def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor: def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
""" """Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps."""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device) dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
# pointers to indicate the most current result # pointers to indicate the most current result
...@@ -458,23 +319,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -458,23 +319,7 @@ class ShallowWaterSolver(nn.Module):
return uspec return uspec
def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0): def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
""" """Integrate the solution on the grid."""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
dlon = 2 * torch.pi / self.nlon dlon = 2 * torch.pi / self.nlon
radius = 1 if dimensionless else self.radius radius = 1 if dimensionless else self.radius
if polar_opt > 0: if polar_opt > 0:
...@@ -485,9 +330,7 @@ class ShallowWaterSolver(nn.Module): ...@@ -485,9 +330,7 @@ class ShallowWaterSolver(nn.Module):
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
""" """Plotting routine for data on the grid. Requires cartopy for 3d plots."""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
lons = self.lons.squeeze() - torch.pi lons = self.lons.squeeze() - torch.pi
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment