Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
import functools
import inspect
import os
import sys
import importlib
import click
class VerosSetting(click.ParamType):
name = "setting"
current_key = None
def convert(self, value, param, ctx):
from veros.settings import SETTINGS
assert param.nargs == 2
if self.current_key is None:
if value not in SETTINGS:
self.fail(f"Unknown setting {value}")
self.current_key = value
return value
assert self.current_key in SETTINGS
setting = SETTINGS[self.current_key]
self.current_key = None
if setting.type is bool:
return click.BOOL(value)
return setting.type(value)
def _import_from_file(path):
module = os.path.basename(path).split(".py")[0]
spec = importlib.util.spec_from_file_location(module, path)
mod = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = mod
spec.loader.exec_module(mod)
return mod
def run(setup_file, *args, **kwargs):
"""Runs a Veros setup from given file"""
from veros import runtime_settings, VerosSetup, __version__ as veros_version
kwargs["override"] = dict(kwargs["override"])
runtime_setting_kwargs = (
"backend",
"profile_mode",
"num_proc",
"loglevel",
"device",
"float_type",
"diskless_mode",
"force_overwrite",
)
for setting in runtime_setting_kwargs:
setattr(runtime_settings, setting, kwargs.pop(setting))
runtime_settings.setup_file = setup_file
# determine setup class from given Python file
setup_module = _import_from_file(setup_file)
SetupClass = None
for obj in vars(setup_module).values():
if inspect.isclass(obj) and issubclass(obj, VerosSetup) and obj is not VerosSetup:
if SetupClass is not None and SetupClass is not obj:
raise RuntimeError("Veros setups can only define one VerosSetup class")
SetupClass = obj
from veros import logger
target_version = getattr(setup_module, "__VEROS_VERSION__", None)
if target_version and target_version != veros_version:
logger.warning(
f"This is Veros v{veros_version}, but the given setup was generated with v{target_version}. "
"Consider switching to this version of Veros or updating your setup file.\n"
)
sim = SetupClass(*args, **kwargs)
sim.setup()
sim.run()
@click.command("veros-run")
@click.argument("SETUP_FILE", type=click.Path(readable=True, dir_okay=False, resolve_path=True, exists=True))
@click.option(
"-b",
"--backend",
default="numpy",
type=click.Choice(["numpy", "jax"]),
help="Backend to use for computations",
show_default=True,
)
@click.option(
"--device",
default="cpu",
type=click.Choice(["cpu", "gpu"]),
help="Hardware device to use (JAX backend only)",
show_default=True,
)
@click.option(
"-v",
"--loglevel",
default="info",
type=click.Choice(["trace", "debug", "info", "warning", "error"]),
help="Log level used for output",
show_default=True,
)
@click.option(
"-s",
"--override",
nargs=2,
multiple=True,
metavar="SETTING VALUE",
type=VerosSetting(),
default=tuple(),
help="Override model setting, may be specified multiple times",
)
@click.option(
"-p",
"--profile-mode",
is_flag=True,
default=False,
type=click.BOOL,
envvar="VEROS_PROFILE",
help="Write a performance profile for debugging",
show_default=True,
)
@click.option("--force-overwrite", is_flag=True, help="Silently overwrite existing outputs")
@click.option("--diskless-mode", is_flag=True, help="Supress all output to disk")
@click.option(
"--float-type",
default="float64",
type=click.Choice(["float64", "float32"]),
help="Floating point precision to use",
show_default=True,
)
@click.option(
"-n", "--num-proc", nargs=2, default=[1, 1], type=click.INT, help="Number of processes in x and y dimension"
)
@functools.wraps(run)
def cli(setup_file, *args, **kwargs):
if not setup_file.endswith(".py"):
raise click.UsageError(f"The given setup file {setup_file} does not appear to be a Python file.")
return run(setup_file, *args, **kwargs)
# Veros core modules
## Structure
Modules ending with an underscore (e.g. `petsc_.py`) are *optional* and are not imported automatically.
import os
import importlib
from veros import logger
def build_all():
"""Trigger first import of all core modules"""
from veros import runtime_settings as rs
from veros.backend import BACKEND_MESSAGES, get_curent_device_name
logger.info("Importing core modules")
logger.opt(colors=True).info(
" Using computational backend <bold>{}</bold> on <bold>{}</bold>", rs.backend, get_curent_device_name()
)
extra_message = BACKEND_MESSAGES.get(rs.backend)
if extra_message:
logger.info(" {}", extra_message)
basedir = os.path.dirname(__file__)
for root, dirs, files in os.walk(basedir):
py_path = ".".join(os.path.split(os.path.relpath(root, basedir))).strip(".")
for f in files:
modname, ext = os.path.splitext(f)
if modname.endswith("_") or ext != ".py":
continue
if py_path:
module_path = f"veros.core.{py_path}.{modname}"
else:
module_path = f"veros.core.{modname}"
logger.trace("importing {}", module_path)
try:
importlib.import_module(module_path)
except ImportError:
pass
if not rs.__locked__:
rs.__locked__ = True
logger.info(" Runtime settings are now locked")
logger.info("")
build_all()
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core.utilities import pad_z_edges
from veros.core.operators import numpy as npx, update, update_add, update_multiply, at
@veros_kernel
def _calc_cr(rjp, rj, rjm, vel):
"""
Calculates cr value used in superbee advection scheme
"""
eps = 1e-20 # prevent division by 0
return npx.where(vel > 0.0, rjm, rjp) / npx.where(npx.abs(rj) < eps, eps, rj)
@veros_kernel
def limiter(cr):
return npx.maximum(npx.clip(2 * cr, 0, 1), npx.clip(cr, 0, 2))
@veros_kernel(static_args=("axis"))
def _adv_superbee(state, vel, var, mask, dx, axis):
vs = state.variables
settings = state.settings
if axis == 0:
sm1, s, sp1, sp2 = ((slice(1 + n, -2 + n or None), slice(2, -2), slice(None)) for n in range(-1, 3))
dx = vs.cost[npx.newaxis, 2:-2, npx.newaxis] * dx[1:-2, npx.newaxis, npx.newaxis]
elif axis == 1:
sm1, s, sp1, sp2 = ((slice(2, -2), slice(1 + n, -2 + n or None), slice(None)) for n in range(-1, 3))
dx = (vs.cost * dx)[npx.newaxis, 1:-2, npx.newaxis]
elif axis == 2:
sm1, s, sp1, sp2 = ((slice(2, -2), slice(2, -2), slice(1 + n, -2 + n or None)) for n in range(-1, 3))
dx = dx[npx.newaxis, npx.newaxis, :-1]
vel, var, mask = (pad_z_edges(a) for a in (vel, var, mask))
else:
raise ValueError("axis must be 0, 1, or 2")
rjp = (var[sp2] - var[sp1]) * mask[sp1]
rj = (var[sp1] - var[s]) * mask[s]
rjm = (var[s] - var[sm1]) * mask[sm1]
cr = limiter(_calc_cr(rjp, rj, rjm, vel[s]))
if axis == 1:
vel = vel * vs.cosu[npx.newaxis, :, npx.newaxis]
uCFL = npx.abs(vel[s] * settings.dt_tracer / dx)
return vel[s] * (var[sp1] + var[s]) * 0.5 - npx.abs(vel[s]) * ((1.0 - cr) + uCFL * cr) * rj * 0.5
@veros_kernel
def adv_flux_2nd(state, var):
"""
2nd order advective tracer flux
"""
vs = state.variables
adv_fe = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fn = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_ft = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fe = update(
adv_fe,
at[1:-2, 2:-2, :],
0.5 * (var[1:-2, 2:-2, :] + var[2:-1, 2:-2, :]) * vs.u[1:-2, 2:-2, :, vs.tau] * vs.maskU[1:-2, 2:-2, :],
)
adv_fn = update(
adv_fn,
at[2:-2, 1:-2, :],
vs.cosu[npx.newaxis, 1:-2, npx.newaxis]
* 0.5
* (var[2:-2, 1:-2, :] + var[2:-2, 2:-1, :])
* vs.v[2:-2, 1:-2, :, vs.tau]
* vs.maskV[2:-2, 1:-2, :],
)
adv_ft = update(
adv_ft,
at[2:-2, 2:-2, :-1],
0.5 * (var[2:-2, 2:-2, :-1] + var[2:-2, 2:-2, 1:]) * vs.w[2:-2, 2:-2, :-1, vs.tau] * vs.maskW[2:-2, 2:-2, :-1],
)
adv_ft = update(adv_ft, at[:, :, -1], 0.0)
return adv_fe, adv_fn, adv_ft
@veros_kernel
def adv_flux_superbee(state, var):
r"""
from MITgcm
Calculates advection of a tracer
using second-order interpolation with a flux limiter:
\begin{equation*}
F^x_{adv} = U \overline{ \theta }^i
- \frac{1}{2} \left([ 1 - \psi(C_r) ] |U|
+ U \frac{u \Delta t}{\Delta x_c} \psi(C_r)
\right) \delta_i \theta
\end{equation*}
where the $\psi(C_r)$ is the limiter function and $C_r$ is
the slope ratio.
"""
vs = state.variables
adv_fe = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fn = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_ft = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fe = update(adv_fe, at[1:-2, 2:-2, :], _adv_superbee(state, vs.u[..., vs.tau], var, vs.maskU, vs.dxt, 0))
adv_fn = update(adv_fn, at[2:-2, 1:-2, :], _adv_superbee(state, vs.v[..., vs.tau], var, vs.maskV, vs.dyt, 1))
adv_ft = update(adv_ft, at[2:-2, 2:-2, :-1], _adv_superbee(state, vs.w[..., vs.tau], var, vs.maskW, vs.dzt, 2))
adv_ft = update(adv_ft, at[..., -1], 0.0)
return adv_fe, adv_fn, adv_ft
@veros_routine
def calculate_velocity_on_wgrid(state):
vs = state.variables
vs.update(calculate_velocity_on_wgrid_kernel(state))
@veros_kernel
def calculate_velocity_on_wgrid_kernel(state):
"""
calculates advection velocity for tracer on vs.W grid
Note: this implementation is not strictly equal to the Fortran version. They only match
if vs.maskW has exactly one true value across each depth slice.
"""
vs = state.variables
# lateral advection velocities on W grid
vs.u_wgrid = update(
vs.u_wgrid,
at[:, :, :-1],
vs.u[:, :, 1:, vs.tau]
* vs.maskU[:, :, 1:]
* 0.5
* vs.dzt[npx.newaxis, npx.newaxis, 1:]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
+ vs.u[:, :, :-1, vs.tau]
* vs.maskU[:, :, :-1]
* 0.5
* vs.dzt[npx.newaxis, npx.newaxis, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
vs.v_wgrid = update(
vs.v_wgrid,
at[:, :, :-1],
vs.v[:, :, 1:, vs.tau]
* vs.maskV[:, :, 1:]
* 0.5
* vs.dzt[npx.newaxis, npx.newaxis, 1:]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
+ vs.v[:, :, :-1, vs.tau]
* vs.maskV[:, :, :-1]
* 0.5
* vs.dzt[npx.newaxis, npx.newaxis, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
vs.u_wgrid = update(
vs.u_wgrid, at[:, :, -1], vs.u[:, :, -1, vs.tau] * vs.maskU[:, :, -1] * 0.5 * vs.dzt[-1:] / vs.dzw[-1:]
)
vs.v_wgrid = update(
vs.v_wgrid, at[:, :, -1], vs.v[:, :, -1, vs.tau] * vs.maskV[:, :, -1] * 0.5 * vs.dzt[-1:] / vs.dzw[-1:]
)
# redirect velocity at bottom and at topography
vs.u_wgrid = update(
vs.u_wgrid,
at[:, :, 0],
vs.u_wgrid[:, :, 0] + vs.u[:, :, 0, vs.tau] * vs.maskU[:, :, 0] * 0.5 * vs.dzt[0] / vs.dzw[0],
)
vs.v_wgrid = update(
vs.v_wgrid,
at[:, :, 0],
vs.v_wgrid[:, :, 0] + vs.v[:, :, 0, vs.tau] * vs.maskV[:, :, 0] * 0.5 * vs.dzt[0] / vs.dzw[0],
)
mask = vs.maskW[:-1, :, :-1] * vs.maskW[1:, :, :-1]
vs.u_wgrid = update_add(
vs.u_wgrid,
at[:-1, :, 1:],
(vs.u_wgrid[:-1, :, :-1] * vs.dzw[npx.newaxis, npx.newaxis, :-1] / vs.dzw[npx.newaxis, npx.newaxis, 1:])
* (1.0 - mask),
)
vs.u_wgrid = update_multiply(vs.u_wgrid, at[:-1, :, :-1], mask)
mask = vs.maskW[:, :-1, :-1] * vs.maskW[:, 1:, :-1]
vs.v_wgrid = update_add(
vs.v_wgrid,
at[:, :-1, 1:],
(vs.v_wgrid[:, :-1, :-1] * vs.dzw[npx.newaxis, npx.newaxis, :-1] / vs.dzw[npx.newaxis, npx.newaxis, 1:])
* (1.0 - mask),
)
vs.v_wgrid = update_multiply(vs.v_wgrid, at[:, :-1, :-1], mask)
# vertical advection velocity on W grid from continuity
vs.w_wgrid = update(vs.w_wgrid, at[:, :, 0], 0.0)
vs.w_wgrid = update(
vs.w_wgrid,
at[1:, 1:, :],
npx.cumsum(
-vs.dzw[npx.newaxis, npx.newaxis, :]
* (
(vs.u_wgrid[1:, 1:, :] - vs.u_wgrid[:-1, 1:, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (
vs.cosu[npx.newaxis, 1:, npx.newaxis] * vs.v_wgrid[1:, 1:, :]
- vs.cosu[npx.newaxis, :-1, npx.newaxis] * vs.v_wgrid[1:, :-1, :]
)
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis])
),
axis=2,
),
)
return KernelOutput(u_wgrid=vs.u_wgrid, v_wgrid=vs.v_wgrid, w_wgrid=vs.w_wgrid)
@veros_kernel
def adv_flux_superbee_wgrid(state, var):
"""
Calculates advection of a tracer defined on Wgrid
"""
vs = state.variables
adv_fe = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fn = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_ft = allocate(state.dimensions, ("xt", "yt", "zt"))
maskUtr = allocate(state.dimensions, ("xt", "yt", "zw"))
maskUtr = update(maskUtr, at[:-1, :, :], vs.maskW[1:, :, :] * vs.maskW[:-1, :, :])
adv_fe = update(adv_fe, at[1:-2, 2:-2, :], _adv_superbee(state, vs.u_wgrid, var, maskUtr, vs.dxt, axis=0))
maskVtr = allocate(state.dimensions, ("xt", "yt", "zw"))
maskVtr = update(maskVtr, at[:, :-1, :], vs.maskW[:, 1:, :] * vs.maskW[:, :-1, :])
adv_fn = update(adv_fn, at[2:-2, 1:-2, :], _adv_superbee(state, vs.v_wgrid, var, maskVtr, vs.dyt, axis=1))
maskWtr = allocate(state.dimensions, ("xt", "yt", "zw"))
maskWtr = update(maskWtr, at[:, :, :-1], vs.maskW[:, :, 1:] * vs.maskW[:, :, :-1])
adv_ft = update(adv_ft, at[2:-2, 2:-2, :-1], _adv_superbee(state, vs.w_wgrid, var, maskWtr, vs.dzw, axis=2))
adv_ft = update(adv_ft, at[..., -1], 0.0)
return adv_fe, adv_fn, adv_ft
@veros_kernel
def adv_flux_upwind_wgrid(state, var):
"""
Calculates advection of a tracer defined on Wgrid
"""
vs = state.variables
adv_fe = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_fn = allocate(state.dimensions, ("xt", "yt", "zt"))
adv_ft = allocate(state.dimensions, ("xt", "yt", "zt"))
maskUtr = vs.maskW[2:-1, 2:-2, :] * vs.maskW[1:-2, 2:-2, :]
rj = (var[2:-1, 2:-2, :] - var[1:-2, 2:-2, :]) * maskUtr
adv_fe = update(
adv_fe,
at[1:-2, 2:-2, :],
vs.u_wgrid[1:-2, 2:-2, :] * (var[2:-1, 2:-2, :] + var[1:-2, 2:-2, :]) * 0.5
- npx.abs(vs.u_wgrid[1:-2, 2:-2, :]) * rj * 0.5,
)
maskVtr = vs.maskW[2:-2, 2:-1, :] * vs.maskW[2:-2, 1:-2, :]
rj = (var[2:-2, 2:-1, :] - var[2:-2, 1:-2, :]) * maskVtr
adv_fn = update(
adv_fn,
at[2:-2, 1:-2, :],
vs.cosu[npx.newaxis, 1:-2, npx.newaxis]
* vs.v_wgrid[2:-2, 1:-2, :]
* (var[2:-2, 2:-1, :] + var[2:-2, 1:-2, :])
* 0.5
- npx.abs(vs.cosu[npx.newaxis, 1:-2, npx.newaxis] * vs.v_wgrid[2:-2, 1:-2, :]) * rj * 0.5,
)
maskWtr = vs.maskW[2:-2, 2:-2, 1:] * vs.maskW[2:-2, 2:-2, :-1]
rj = (var[2:-2, 2:-2, 1:] - var[2:-2, 2:-2, :-1]) * maskWtr
adv_ft = update(
adv_ft,
at[2:-2, 2:-2, :-1],
vs.w_wgrid[2:-2, 2:-2, :-1] * (var[2:-2, 2:-2, 1:] + var[2:-2, 2:-2, :-1]) * 0.5
- npx.abs(vs.w_wgrid[2:-2, 2:-2, :-1]) * rj * 0.5,
)
adv_ft = update(adv_ft, at[:, :, -1], 0.0)
return adv_fe, adv_fn, adv_ft
from veros.core.density.get_rho import ( # noqa: F401
get_rho,
get_potential_rho,
get_dyn_enthalpy,
get_salt,
get_drhodT,
get_drhodS,
get_drhodp,
get_int_drhodT,
get_int_drhodS,
)
from veros import veros_kernel
from veros.core.density import gsw, linear_eq as lq, nonlinear_eq1 as nq1, nonlinear_eq2 as nq2, nonlinear_eq3 as nq3
@veros_kernel
def get_rho(state, salt_loc, temp_loc, press):
"""
calculate density as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_rho(salt_loc, temp_loc, press)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 5:
return gsw.gsw_rho(salt_loc, temp_loc, press)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_potential_rho(state, salt_loc, temp_loc, press_ref=0.0):
"""
calculate potential density as a function of temperature, salinity
and reference pressure
Note:
This is identical to get_rho for eq_of_state_type {1, 2, 4}
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_rho(salt_loc, temp_loc, press_ref)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_rho(salt_loc, temp_loc)
elif settings.eq_of_state_type == 5:
return gsw.gsw_rho(salt_loc, temp_loc, press_ref)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_dyn_enthalpy(state, salt_loc, temp_loc, press):
"""
calculate dynamic enthalpy as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_dyn_enthalpy(salt_loc, temp_loc, press)
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_dyn_enthalpy(salt_loc, temp_loc, press)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_dyn_enthalpy(salt_loc, temp_loc, press)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_dyn_enthalpy(salt_loc, temp_loc, press)
elif settings.eq_of_state_type == 5:
return gsw.gsw_dyn_enthalpy(salt_loc, temp_loc, press)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_salt(state, rho_loc, temp_loc, press_loc):
"""
calculate salinity as a function of density, temperature and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_salt(rho_loc, temp_loc)
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_salt(rho_loc, temp_loc)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_salt(rho_loc, temp_loc, press_loc)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_salt(rho_loc, temp_loc)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_drhodT(state, salt_loc, temp_loc, press_loc):
"""
calculate drho/dT as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_drhodT()
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_drhodT(temp_loc)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_drhodT(temp_loc, press_loc)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_drhodT(temp_loc)
elif settings.eq_of_state_type == 5:
return gsw.gsw_drhodT(salt_loc, temp_loc, press_loc)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_drhodS(state, salt_loc, temp_loc, press_loc):
"""
calculate drho/dS as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_drhodS()
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_drhodS()
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_drhodS()
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_drhodS()
elif settings.eq_of_state_type == 5:
return gsw.gsw_drhodS(salt_loc, temp_loc, press_loc)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_drhodp(state, salt_loc, temp_loc, press_loc):
"""
calculate drho/dP as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return lq.linear_eq_of_state_drhodp()
elif settings.eq_of_state_type == 2:
return nq1.nonlin1_eq_of_state_drhodp()
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_drhodp(temp_loc)
elif settings.eq_of_state_type == 4:
return nq3.nonlin3_eq_of_state_drhodp()
elif settings.eq_of_state_type == 5:
return gsw.gsw_drhodp(salt_loc, temp_loc, press_loc)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_int_drhodT(state, salt_loc, temp_loc, press_loc):
"""
calculate int_z^0 drho/dT dz' as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return press_loc * lq.linear_eq_of_state_drhodT() # int_z^0rho_T dz = - rho_T z
elif settings.eq_of_state_type == 2:
return press_loc * nq1.nonlin1_eq_of_state_drhodT(temp_loc)
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_int_drhodT(temp_loc, press_loc)
elif settings.eq_of_state_type == 4:
return press_loc * nq3.nonlin3_eq_of_state_drhodT(temp_loc)
elif settings.eq_of_state_type == 5:
return -(1024.0 / 9.81) * gsw.gsw_dHdT(salt_loc, temp_loc, press_loc)
else:
raise ValueError("unknown equation of state")
@veros_kernel
def get_int_drhodS(state, salt_loc, temp_loc, press_loc):
"""
calculate int_z^0 drho/dS dz' as a function of temperature, salinity and pressure
"""
settings = state.settings
if settings.eq_of_state_type == 1:
return press_loc * lq.linear_eq_of_state_drhodS() # int_z^0rho_T dz = - rho_T z
elif settings.eq_of_state_type == 2:
return press_loc * nq1.nonlin1_eq_of_state_drhodS()
elif settings.eq_of_state_type == 3:
return nq2.nonlin2_eq_of_state_int_drhodS(press_loc)
elif settings.eq_of_state_type == 4:
return press_loc * nq3.nonlin3_eq_of_state_drhodS()
elif settings.eq_of_state_type == 5:
return -(1024.0 / 9.81) * gsw.gsw_dHdS(salt_loc, temp_loc, press_loc)
else:
raise ValueError("unknown equation of state")
from veros.core.operators import numpy as npx
from veros import veros_kernel, runtime_settings
"""
==========================================================================
in-situ density, dynamic enthalpy and derivatives
from Absolute Salinity and Conservative
Temperature, using the computationally-efficient 48-term expression for
density in terms of SA, CT and p (IOC et al., 2010).
==========================================================================
"""
v01 = 9.998420897506056e2
v02 = 2.839940833161907e0
v03 = -3.147759265588511e-2
v04 = 1.181805545074306e-3
v05 = -6.698001071123802e0
v06 = -2.986498947203215e-2
v07 = 2.327859407479162e-4
v08 = -3.988822378968490e-2
v09 = 5.095422573880500e-4
v10 = -1.426984671633621e-5
v11 = 1.645039373682922e-7
v12 = -2.233269627352527e-2
v13 = -3.436090079851880e-4
v14 = 3.726050720345733e-6
v15 = -1.806789763745328e-4
v16 = 6.876837219536232e-7
v17 = -3.087032500374211e-7
v18 = -1.988366587925593e-8
v19 = -1.061519070296458e-11
v20 = 1.550932729220080e-10
v21 = 1.0e0
v22 = 2.775927747785646e-3
v23 = -2.349607444135925e-5
v24 = 1.119513357486743e-6
v25 = 6.743689325042773e-10
v26 = -7.521448093615448e-3
v27 = -2.764306979894411e-5
v28 = 1.262937315098546e-7
v29 = 9.527875081696435e-10
v30 = -1.811147201949891e-11
v31 = -3.303308871386421e-5
v32 = 3.801564588876298e-7
v33 = -7.672876869259043e-9
v34 = -4.634182341116144e-11
v35 = 2.681097235569143e-12
v36 = 5.419326551148740e-6
v37 = -2.742185394906099e-5
v38 = -3.212746477974189e-7
v39 = 3.191413910561627e-9
v40 = -1.931012931541776e-12
v41 = -1.105097577149576e-7
v42 = 6.211426728363857e-10
v43 = -1.119011592875110e-10
v44 = -1.941660213148725e-11
v45 = -1.864826425365600e-14
v46 = 1.119522344879478e-14
v47 = -1.200507748551599e-15
v48 = 6.057902487546866e-17
rho0 = 1024.0
@veros_kernel
def gsw_rho(sa, ct, p):
"""
density as a function of T, S, and p
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
==========================================================================
"""
# convert scalar values if necessary
sa, ct, p = npx.asarray(sa), npx.asarray(ct), npx.asarray(p)
sqrtsa = npx.sqrt(sa)
v_hat_denominator = (
v01
+ ct * (v02 + ct * (v03 + v04 * ct))
+ sa * (v05 + ct * (v06 + v07 * ct) + sqrtsa * (v08 + ct * (v09 + ct * (v10 + v11 * ct))))
+ p * (v12 + ct * (v13 + v14 * ct) + sa * (v15 + v16 * ct) + p * (v17 + ct * (v18 + v19 * ct) + v20 * sa))
)
v_hat_numerator = (
v21
+ ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct)))
+ sa
* (
v26
+ ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
+ v36 * sa
+ sqrtsa * (v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct))))
)
+ p
* (
v37
+ ct * (v38 + ct * (v39 + v40 * ct))
+ sa * (v41 + v42 * ct)
+ p * (v43 + ct * (v44 + v45 * ct + v46 * sa) + p * (v47 + v48 * ct))
)
)
return v_hat_denominator / v_hat_numerator - rho0
@veros_kernel
def gsw_drhodT(sa, ct, p):
"""
d/dT of density
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
==========================================================================
"""
p = npx.asarray(p) # convert scalar value if necessary
a01 = 2.839940833161907e0
a02 = -6.295518531177023e-2
a03 = 3.545416635222918e-3
a04 = -2.986498947203215e-2
a05 = 4.655718814958324e-4
a06 = 5.095422573880500e-4
a07 = -2.853969343267241e-5
a08 = 4.935118121048767e-7
a09 = -3.436090079851880e-4
a10 = 7.452101440691467e-6
a11 = 6.876837219536232e-7
a12 = -1.988366587925593e-8
a13 = -2.123038140592916e-11
a14 = 2.775927747785646e-3
a15 = -4.699214888271850e-5
a16 = 3.358540072460230e-6
a17 = 2.697475730017109e-9
a18 = -2.764306979894411e-5
a19 = 2.525874630197091e-7
a20 = 2.858362524508931e-9
a21 = -7.244588807799565e-11
a22 = 3.801564588876298e-7
a23 = -1.534575373851809e-8
a24 = -1.390254702334843e-10
a25 = 1.072438894227657e-11
a26 = -3.212746477974189e-7
a27 = 6.382827821123254e-9
a28 = -5.793038794625329e-12
a29 = 6.211426728363857e-10
a30 = -1.941660213148725e-11
a31 = -3.729652850731201e-14
a32 = 1.119522344879478e-14
a33 = 6.057902487546866e-17
sqrtsa = npx.sqrt(sa)
v_hat_denominator = (
v01
+ ct * (v02 + ct * (v03 + v04 * ct))
+ sa * (v05 + ct * (v06 + v07 * ct) + sqrtsa * (v08 + ct * (v09 + ct * (v10 + v11 * ct))))
+ p * (v12 + ct * (v13 + v14 * ct) + sa * (v15 + v16 * ct) + p * (v17 + ct * (v18 + v19 * ct) + v20 * sa))
)
v_hat_numerator = (
v21
+ ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct)))
+ sa
* (
v26
+ ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
+ v36 * sa
+ sqrtsa * (v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct))))
)
+ p
* (
v37
+ ct * (v38 + ct * (v39 + v40 * ct))
+ sa * (v41 + v42 * ct)
+ p * (v43 + ct * (v44 + v45 * ct + v46 * sa) + p * (v47 + v48 * ct))
)
)
dvhatden_dct = (
a01
+ ct * (a02 + a03 * ct)
+ sa * (a04 + a05 * ct + sqrtsa * (a06 + ct * (a07 + a08 * ct)))
+ p * (a09 + a10 * ct + a11 * sa + p * (a12 + a13 * ct))
)
dvhatnum_dct = (
a14
+ ct * (a15 + ct * (a16 + a17 * ct))
+ sa * (a18 + ct * (a19 + ct * (a20 + a21 * ct)) + sqrtsa * (a22 + ct * (a23 + ct * (a24 + a25 * ct))))
+ p * (a26 + ct * (a27 + a28 * ct) + a29 * sa + p * (a30 + a31 * ct + a32 * sa + a33 * p))
)
rec_num = 1.0 / v_hat_numerator
rho = rec_num * v_hat_denominator
return (dvhatden_dct - dvhatnum_dct * rho) * rec_num
@veros_kernel
def gsw_drhodS(sa, ct, p):
"""
d/dS of density
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
==========================================================================
"""
p = npx.asarray(p) # convert scalar value if necessary
b01 = -6.698001071123802e0
b02 = -2.986498947203215e-2
b03 = 2.327859407479162e-4
b04 = -5.983233568452735e-2
b05 = 7.643133860820750e-4
b06 = -2.140477007450431e-5
b07 = 2.467559060524383e-7
b08 = -1.806789763745328e-4
b09 = 6.876837219536232e-7
b10 = 1.550932729220080e-10
b11 = -7.521448093615448e-3
b12 = -2.764306979894411e-5
b13 = 1.262937315098546e-7
b14 = 9.527875081696435e-10
b15 = -1.811147201949891e-11
b16 = -4.954963307079632e-5
b17 = 5.702346883314446e-7
b18 = -1.150931530388857e-8
b19 = -6.951273511674217e-11
b20 = 4.021645853353715e-12
b21 = 1.083865310229748e-5
b22 = -1.105097577149576e-7
b23 = 6.211426728363857e-10
b24 = 1.119522344879478e-14
sqrtsa = npx.sqrt(sa)
v_hat_denominator = (
v01
+ ct * (v02 + ct * (v03 + v04 * ct))
+ sa * (v05 + ct * (v06 + v07 * ct) + sqrtsa * (v08 + ct * (v09 + ct * (v10 + v11 * ct))))
+ p * (v12 + ct * (v13 + v14 * ct) + sa * (v15 + v16 * ct) + p * (v17 + ct * (v18 + v19 * ct) + v20 * sa))
)
v_hat_numerator = (
v21
+ ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct)))
+ sa
* (
v26
+ ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
+ v36 * sa
+ sqrtsa * (v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct))))
)
+ p
* (
v37
+ ct * (v38 + ct * (v39 + v40 * ct))
+ sa * (v41 + v42 * ct)
+ p * (v43 + ct * (v44 + v45 * ct + v46 * sa) + p * (v47 + v48 * ct))
)
)
dvhatden_dsa = (
b01
+ ct * (b02 + b03 * ct)
+ sqrtsa * (b04 + ct * (b05 + ct * (b06 + b07 * ct)))
+ p * (b08 + b09 * ct + b10 * p)
)
dvhatnum_dsa = (
b11
+ ct * (b12 + ct * (b13 + ct * (b14 + b15 * ct)))
+ sqrtsa * (b16 + ct * (b17 + ct * (b18 + ct * (b19 + b20 * ct))))
+ b21 * sa
+ p * (b22 + ct * (b23 + b24 * p))
)
rec_num = 1.0 / v_hat_numerator
rho = rec_num * v_hat_denominator
return (dvhatden_dsa - dvhatnum_dsa * rho) * rec_num
@veros_kernel
def gsw_drhodP(sa, ct, p):
"""
d/dp of density
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
==========================================================================
"""
p = npx.asarray(p) # convert scalar value if necessary
c01 = -2.233269627352527e-2
c02 = -3.436090079851880e-4
c03 = 3.726050720345733e-6
c04 = -1.806789763745328e-4
c05 = 6.876837219536232e-7
c06 = -6.174065000748422e-7
c07 = -3.976733175851186e-8
c08 = -2.123038140592916e-11
c09 = 3.101865458440160e-10
c10 = -2.742185394906099e-5
c11 = -3.212746477974189e-7
c12 = 3.191413910561627e-9
c13 = -1.931012931541776e-12
c14 = -1.105097577149576e-7
c15 = 6.211426728363857e-10
c16 = -2.238023185750219e-10
c17 = -3.883320426297450e-11
c18 = -3.729652850731201e-14
c19 = 2.239044689758956e-14
c20 = -3.601523245654798e-15
c21 = 1.817370746264060e-16
pa2db = 1e-4
sqrtsa = npx.sqrt(sa)
v_hat_denominator = (
v01
+ ct * (v02 + ct * (v03 + v04 * ct))
+ sa * (v05 + ct * (v06 + v07 * ct) + sqrtsa * (v08 + ct * (v09 + ct * (v10 + v11 * ct))))
+ p * (v12 + ct * (v13 + v14 * ct) + sa * (v15 + v16 * ct) + p * (v17 + ct * (v18 + v19 * ct) + v20 * sa))
)
v_hat_numerator = (
v21
+ ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct)))
+ sa
* (
v26
+ ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
+ v36 * sa
+ sqrtsa * (v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct))))
)
+ p
* (
v37
+ ct * (v38 + ct * (v39 + v40 * ct))
+ sa * (v41 + v42 * ct)
+ p * (v43 + ct * (v44 + v45 * ct + v46 * sa) + p * (v47 + v48 * ct))
)
)
dvhatden_dp = c01 + ct * (c02 + c03 * ct) + sa * (c04 + c05 * ct) + p * (c06 + ct * (c07 + c08 * ct) + c09 * sa)
dvhatnum_dp = (
c10
+ ct * (c11 + ct * (c12 + c13 * ct))
+ sa * (c14 + c15 * ct)
+ p * (c16 + ct * (c17 + c18 * ct + c19 * sa) + p * (c20 + c21 * ct))
)
rec_num = 1.0 / v_hat_numerator
rho = rec_num * v_hat_denominator
return pa2db * (dvhatden_dp - dvhatnum_dp * rho) * rec_num
@veros_kernel
def gsw_dyn_enthalpy(sa_in, ct_in, p):
"""
Calculates dynamic enthalpy of seawater using the computationally
efficient 48-term expression for density in terms of SA, CT and p
(IOC et al., 2010)
A component due to the constant reference density in Boussinesq
approximation is removed
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
==========================================================================
"""
p = npx.asarray(p) # convert scalar value if necessary
if runtime_settings.pyom_compatibility_mode:
sa = sa_in
ct = ct_in
else:
sa = npx.maximum(1e-1, sa_in) # prevent division by zero
ct = npx.maximum(-12, ct_in) # prevent blowing up for values smaller than -15 degC
db2pa = 1e4 # factor to convert from dbar to Pa
sqrtsa = npx.sqrt(sa)
a0 = (
v21
+ ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct)))
+ sa
* (
v26
+ ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
+ v36 * sa
+ sqrtsa * (v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct))))
)
)
a1 = v37 + ct * (v38 + ct * (v39 + v40 * ct)) + sa * (v41 + v42 * ct)
a2 = v43 + ct * (v44 + v45 * ct + v46 * sa)
a3 = v47 + v48 * ct
b0 = (
v01
+ ct * (v02 + ct * (v03 + v04 * ct))
+ sa * (v05 + ct * (v06 + v07 * ct) + sqrtsa * (v08 + ct * (v09 + ct * (v10 + v11 * ct))))
)
b1 = 0.5 * (v12 + ct * (v13 + v14 * ct) + sa * (v15 + v16 * ct))
b2 = v17 + ct * (v18 + v19 * ct) + v20 * sa
b1sq = b1 * b1
sqrt_disc = npx.sqrt(b1sq - b0 * b2)
cn = a0 + (2 * a3 * b0 * b1 / b2 - a2 * b0) / b2
cm = a1 + (4 * a3 * b1sq / b2 - a3 * b0 - 2 * a2 * b1) / b2
ca = b1 - sqrt_disc
cb = b1 + sqrt_disc
part = (cn * b2 - cm * b1) / (b2 * (cb - ca))
Hd = db2pa * (
p * (a2 - 2.0 * a3 * b1 / b2 + 0.5 * a3 * p) / b2
+ (cm / (2.0 * b2)) * npx.log(1.0 + p * (2.0 * b1 + b2 * p) / b0)
+ part * npx.log(1.0 + (b2 * p * (cb - ca)) / (ca * (cb + b2 * p)))
)
return Hd - p * db2pa / rho0
@veros_kernel
def gsw_dHdT(sa_in, ct_in, p):
"""
d/dT of dynamic enthalpy, analytical derivative
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
"""
p = npx.asarray(p) # convert scalar value if necessary
sa = npx.maximum(1e-1, sa_in) # prevent division by zero
ct = npx.maximum(-12, ct_in) # prevent blowing up for values smaller than -15 degC
t1 = v45 * ct
t2 = 0.2e1 * t1
t3 = v46 * sa
t4 = 0.5 * v12
t5 = v14 * ct
t7 = ct * (v13 + t5)
t8 = 0.5 * t7
t11 = sa * (v15 + v16 * ct)
t12 = 0.5 * t11
t13 = t4 + t8 + t12
t15 = v19 * ct
t19 = v17 + ct * (v18 + t15) + v20 * sa
t20 = 1.0 / t19
t24 = v47 + v48 * ct
t25 = 0.5 * v13
t26 = 1.0 * t5
t27 = sa * v16
t28 = 0.5 * t27
t29 = t25 + t26 + t28
t33 = t24 * t13
t34 = t19**2
t35 = 1.0 / t34
t37 = v18 + 2.0 * t15
t38 = t35 * t37
t48 = ct * (v44 + t1 + t3)
t57 = v40 * ct
t59 = ct * (v39 + t57)
t64 = t13**2
t68 = t20 * t29
t71 = t24 * t64
t74 = v04 * ct
t76 = ct * (v03 + t74)
t79 = v07 * ct
t82 = npx.sqrt(sa)
t83 = v11 * ct
t85 = ct * (v10 + t83)
t92 = v01 + ct * (v02 + t76) + sa * (v05 + ct * (v06 + t79) + t82 * (v08 + ct * (v09 + t85)))
t93 = v48 * t92
t105 = v02 + t76 + ct * (v03 + 2.0 * t74) + sa * (v06 + 2.0 * t79 + t82 * (v09 + t85 + ct * (v10 + 2.0 * t83)))
t106 = t24 * t105
t107 = v44 + t2 + t3
t110 = v43 + t48
t117 = t24 * t92
t120 = 4.0 * t71 * t20 - t117 - 2.0 * t110 * t13
t123 = (
v38
+ t59
+ ct * (v39 + 2.0 * t57)
+ sa * v42
+ (4.0 * v48 * t64 * t20 + 8.0 * t33 * t68 - 4.0 * t71 * t38 - t93 - t106 - 2.0 * t107 * t13 - 2.0 * t110 * t29)
* t20
- t120 * t35 * t37
)
t128 = t19 * p
t130 = p * (1.0 * v12 + 1.0 * t7 + 1.0 * t11 + t128)
t131 = 1.0 / t92
t133 = 1.0 + t130 * t131
t134 = npx.log(t133)
t143 = v37 + ct * (v38 + t59) + sa * (v41 + v42 * ct) + t120 * t20
t152 = t37 * p
t156 = t92**2
t165 = v25 * ct
t167 = ct * (v24 + t165)
t169 = ct * (v23 + t167)
t175 = v30 * ct
t177 = ct * (v29 + t175)
t179 = ct * (v28 + t177)
t185 = v35 * ct
t187 = ct * (v34 + t185)
t189 = ct * (v33 + t187)
t199 = t13 * t20
t217 = 2.0 * t117 * t199 - t110 * t92
t234 = (
v21
+ ct * (v22 + t169)
+ sa * (v26 + ct * (v27 + t179) + v36 * sa + t82 * (v31 + ct * (v32 + t189)))
+ t217 * t20
)
t241 = t64 - t92 * t19
t242 = npx.sqrt(t241)
t243 = 1.0 / t242
t244 = t4 + t8 + t12 - t242
t245 = 1.0 / t244
t247 = t4 + t8 + t12 + t242 + t128
t248 = 1.0 / t247
t249 = t242 * t245 * t248
t252 = 1.0 + 2.0 * t128 * t249
t253 = npx.log(t252)
t254 = t243 * t253
t259 = t234 * t19 - t143 * t13
t264 = t259 * t20
t272 = 2.0 * t13 * t29 - t105 * t19 - t92 * t37
t282 = t128 * t242
t283 = t244**2
t287 = t243 * t272 / 2.0
t292 = t247**2
t305 = (
0.1e5
* p
* (v44 + t2 + t3 - 2.0 * v48 * t13 * t20 - 2.0 * t24 * t29 * t20 + 2.0 * t33 * t38 + 0.5 * v48 * p)
* t20
- 0.1e5 * p * (v43 + t48 - 2.0 * t33 * t20 + 0.5 * t24 * p) * t38
+ 0.5e4 * t123 * t20 * t134
- 0.5e4 * t143 * t35 * t134 * t37
+ 0.5e4 * t143 * t20 * (p * (1.0 * v13 + 2.0 * t5 + 1.0 * t27 + t152) * t131 - t130 / t156 * t105) / t133
+ 0.5e4
* (
(
v22
+ t169
+ ct * (v23 + t167 + ct * (v24 + 2.0 * t165))
+ sa
* (
v27
+ t179
+ ct * (v28 + t177 + ct * (v29 + 2.0 * t175))
+ t82 * (v32 + t189 + ct * (v33 + t187 + ct * (v34 + 2.0 * t185)))
)
+ (
2.0 * t93 * t199
+ 2.0 * t106 * t199
+ 2.0 * t117 * t68
- 2.0 * t117 * t13 * t35 * t37
- t107 * t92
- t110 * t105
)
* t20
- t217 * t35 * t37
)
* t19
+ t234 * t37
- t123 * t13
- t143 * t29
)
* t20
* t254
- 0.5e4 * t259 * t35 * t254 * t37
- 0.25e4 * t264 / t242 / t241 * t253 * t272
+ 0.5e4
* t264
* t243
* (
2.0 * t152 * t249
+ t128 * t243 * t245 * t248 * t272
- 2.0 * t282 / t283 * t248 * (t25 + t26 + t28 - t287)
- 2.0 * t282 * t245 / t292 * (t25 + t26 + t28 + t287 + t152)
)
/ t252
)
return t305
@veros_kernel
def gsw_dHdS(sa_in, ct_in, p):
"""
d/dS of dynamic enthalpy, analytical derivative
sa : Absolute Salinity [g/kg]
ct : Conservative Temperature [deg C]
p : sea pressure [dbar]
"""
p = npx.asarray(p) # convert scalar value if necessary
sa = npx.maximum(1e-1, sa_in) # prevent division by zero
ct = npx.maximum(-12.0, ct_in) # prevent blowing up for values smaller than -15 degC
t1 = ct * v46
t3 = v47 + v48 * ct
t4 = 0.5 * v15
t5 = v16 * ct
t6 = 0.5 * t5
t7 = t4 + t6
t13 = v17 + ct * (v18 + v19 * ct) + v20 * sa
t14 = 1.0 / t13
t17 = 0.5 * v12
t20 = ct * (v13 + v14 * ct)
t21 = 0.5 * t20
t23 = sa * (v15 + t5)
t24 = 0.5 * t23
t25 = t17 + t21 + t24
t26 = t3 * t25
t27 = t13**2
t28 = 1.0 / t27
t29 = t28 * v20
t39 = ct * (v44 + v45 * ct + v46 * sa)
t48 = v42 * ct
t49 = t14 * t7
t52 = t25**2
t53 = t3 * t52
t58 = ct * (v06 + v07 * ct)
t59 = npx.sqrt(sa)
t66 = t59 * (v08 + ct * (v09 + ct * (v10 + v11 * ct)))
t68 = v05 + t58 + 3.0 / 2.0 * t66
t69 = t3 * t68
t72 = v43 + t39
t86 = v01 + ct * (v02 + ct * (v03 + v04 * ct)) + sa * (v05 + t58 + t66)
t87 = t3 * t86
t90 = 4.0 * t53 * t14 - t87 - 2.0 * t72 * t25
t93 = (
v41 + t48 + (8.0 * t26 * t49 - 4.0 * t53 * t29 - t69 - 2.0 * t1 * t25 - 2.0 * t72 * t7) * t14 - t90 * t28 * v20
)
t98 = t13 * p
t100 = p * (1.0 * v12 + 1.0 * t20 + 1.0 * t23 + t98)
t101 = 1.0 / t86
t103 = 1.0 + t100 * t101
t104 = npx.log(t103)
t115 = v37 + ct * (v38 + ct * (v39 + v40 * ct)) + sa * (v41 + t48) + t90 * t14
t123 = v20 * p
t127 = t86**2
t142 = ct * (v27 + ct * (v28 + ct * (v29 + v30 * ct)))
t143 = v36 * sa
t151 = v31 + ct * (v32 + ct * (v33 + ct * (v34 + v35 * ct)))
t152 = t59 * t151
t158 = t25 * t14
t174 = 2.0 * t87 * t158 - t72 * t86
t189 = v21 + ct * (v22 + ct * (v23 + ct * (v24 + v25 * ct))) + sa * (v26 + t142 + t143 + t152) + t174 * t14
t196 = t52 - t86 * t13
t197 = npx.sqrt(t196)
t198 = 1.0 / t197
t199 = t17 + t21 + t24 - t197
t200 = 1.0 / t199
t202 = t17 + t21 + t24 + t197 + t98
t203 = 1.0 / t202
t204 = t197 * t200 * t203
t207 = 1.0 + 2.0 * t98 * t204
t208 = npx.log(t207)
t209 = t198 * t208
t214 = t189 * t13 - t115 * t25
t219 = t214 * t14
t227 = 2.0 * t25 * t7 - t68 * t13 - t86 * v20
t237 = t98 * t197
t238 = t199**2
t242 = t198 * t227 / 2.0
t247 = t202**2
t260 = (
0.1e5 * p * (t1 - 2.0 * t3 * t7 * t14 + 2.0 * t26 * t29) * t14
- 0.1e5 * p * (v43 + t39 - 2.0 * t26 * t14 + 0.5 * t3 * p) * t29
+ 0.5e4 * t93 * t14 * t104
- 0.5e4 * t115 * t28 * t104 * v20
+ 0.5e4 * t115 * t14 * (p * (1.0 * v15 + 1.0 * t5 + t123) * t101 - t100 / t127 * t68) / t103
+ 0.5e4
* (
(
v26
+ t142
+ t143
+ t152
+ sa * (v36 + 1.0 / t59 * t151 / 2.0)
+ (2.0 * t69 * t158 + 2.0 * t87 * t49 - 2.0 * t87 * t25 * t28 * v20 - t1 * t86 - t72 * t68) * t14
- t174 * t28 * v20
)
* t13
+ t189 * v20
- t93 * t25
- t115 * t7
)
* t14
* t209
- 0.5e4 * t214 * t28 * t209 * v20
- 0.25e4 * t219 / t197 / t196 * t208 * t227
+ 0.5e4
* t219
* t198
* (
2.0 * t123 * t204
+ t98 * t198 * t200 * t203 * t227
- 2.0 * t237 / t238 * t203 * (t4 + t6 - t242)
- 2.0 * t237 * t200 / t247 * (t4 + t6 + t242 + t123)
)
/ t207
)
return t260
"""
==========================================================================
linear equation of state
input is Salinity sa in g/kg,
pot. temperature ct in deg C
==========================================================================
"""
from veros import veros_kernel
rho0 = 1024.0
theta0 = 283.0 - 273.15
S0 = 35.0
betaT = 1.67e-4
betaS = 0.78e-3
grav = 9.81
z0 = 0.0
@veros_kernel
def linear_eq_of_state_rho(sa, ct):
return -(betaT * (ct - theta0) - betaS * (sa - S0)) * rho0
@veros_kernel
def linear_eq_of_state_dyn_enthalpy(sa, ct, p):
zz = -p - z0
thetas = ct - theta0
return grav * zz * (-betaT * thetas + betaS * (sa - S0))
@veros_kernel
def linear_eq_of_state_salt(rho, ct):
return (rho + betaT * (ct - theta0) * rho0) / (betaS * rho0) + S0
@veros_kernel
def linear_eq_of_state_drhodT():
return -betaT * rho0
@veros_kernel
def linear_eq_of_state_drhodS():
return betaS * rho0
@veros_kernel
def linear_eq_of_state_drhodp():
return 0.0
"""
==========================================================================
non-linear equation of state from Vallis 2008
input is Salinity sa in g/kg,
pot. temperature ct in deg C, no pressure dependency
==========================================================================
"""
from veros import veros_kernel
rho0 = 1024.0
theta0 = 283.0 - 273.15
S0 = 35.0
betaT = 1.67e-4
betaTs = 1e-5 / 2.0
betaS = 0.78e-3
grav = 9.81
z0 = 0.0
@veros_kernel
def nonlin1_eq_of_state_rho(sa, ct):
thetas = ct - theta0
return -(betaT * thetas + betaTs * thetas**2 - betaS * (sa - S0)) * rho0
@veros_kernel
def nonlin1_eq_of_state_dyn_enthalpy(sa, ct, p):
zz = -p - z0
thetas = ct - theta0
return grav * zz * (-betaT * thetas - betaTs * thetas**2 + betaS * (sa - S0))
@veros_kernel
def nonlin1_eq_of_state_salt(rho, ct):
thetas = ct - theta0
return (rho + (betaT * thetas + betaTs * thetas**2) * rho0) / (betaS * rho0) + S0
@veros_kernel
def nonlin1_eq_of_state_drhodT(ct):
thetas = ct - theta0
return -(betaT + 2 * betaTs * thetas) * rho0
@veros_kernel
def nonlin1_eq_of_state_drhodS():
return betaS * rho0
@veros_kernel
def nonlin1_eq_of_state_drhodp():
return 0.0
"""
==========================================================================
non-linear equation of state from Vallis 2008
input is Salinity sa in g/kg,
pot. temperature ct in deg C and
pressure p in dbar
==========================================================================
"""
from veros import veros_kernel
rho0 = 1024.0
z0 = 0.0
theta0 = 283.0 - 273.15
S0 = 35.0
grav = 9.81
cs0 = 1490.0
betaT = 1.67e-4
betaTs = 1e-5
betaS = 0.78e-3
gammas = 1.1e-8
@veros_kernel
def nonlin2_eq_of_state_rho(sa, ct, p):
zz = -p - z0
thetas = ct - theta0
return (
-(
grav * zz / cs0**2
+ betaT * (1 - gammas * grav * zz * rho0) * thetas
+ betaTs / 2 * thetas**2
- betaS * (sa - S0)
)
* rho0
)
@veros_kernel
def nonlin2_eq_of_state_dyn_enthalpy(sa, ct, p):
zz = -p - z0
thetas = ct - theta0
return grav * 0.5 * zz**2 * (-grav / cs0**2 + betaT * grav * rho0 * gammas * thetas) + grav * zz * (
-betaT * thetas - betaTs * thetas**2 + betaS * (sa - S0)
)
@veros_kernel
def nonlin2_eq_of_state_salt(rho, ct, p):
zz = -p - z0
thetas = ct - theta0
return (
rho / rho0
+ (grav * zz / cs0**2 + betaT * (1 - gammas * grav * zz * rho0) * thetas + betaTs / 2 * thetas**2)
) / betaS + S0
@veros_kernel
def nonlin2_eq_of_state_drhodT(ct, p):
zz = -p - z0
thetas = ct - theta0
return -(betaTs * thetas + betaT * (1 - gammas * grav * zz * rho0)) * rho0
@veros_kernel
def nonlin2_eq_of_state_drhodS():
return betaS * rho0
@veros_kernel
def nonlin2_eq_of_state_drhodP(ct):
thetas = ct - theta0
return 1 / cs0**2 - betaT * gammas * rho0 * thetas
@veros_kernel
def nonlin2_eq_of_state_int_drhodT(ct, p):
zz = -p - z0
thetas = ct - theta0
return rho0 * zz * (betaT + betaTs * thetas) - rho0 * betaT * gammas * grav * rho0 * zz**2 / 2
@veros_kernel
def nonlin2_eq_of_state_int_drhodS(p):
zz = -p - z0
return -betaS * rho0 * zz
"""
==========================================================================
non-linear equation of state, no salinity dependency
input is Salinity sa in g/kg,
pot. temperature ct in deg C , no pressure dependency
==========================================================================
"""
from veros import veros_kernel
rho0 = 1024.0
theta0 = 283.0 - 273.15
S0 = 35.0
betaT = 1.67e-4
betaTs = 1e-5 / 2.0
betaS = 0
grav = 9.81
z0 = 0.0
@veros_kernel
def nonlin3_eq_of_state_rho(sa, ct):
thetas = ct - theta0
return -(betaT * thetas + betaTs * thetas**2 - betaS * (sa - S0)) * rho0
@veros_kernel
def nonlin3_eq_of_state_dyn_enthalpy(sa, ct, p):
zz = -p - z0
thetas = ct - theta0
return grav * zz * (-betaT * thetas - betaTs * thetas**2 + betaS * (sa - S0))
@veros_kernel
def nonlin3_eq_of_state_salt(rho, ct):
thetas = ct - theta0
return (rho + (betaT * thetas + betaTs * thetas**2) * rho0) / (betaS * rho0) + S0
@veros_kernel
def nonlin3_eq_of_state_drhodT(ct):
thetas = ct - theta0
return -(betaT + 2 * betaTs * thetas) * rho0
@veros_kernel
def nonlin3_eq_of_state_drhodS():
return betaS * rho0
@veros_kernel
def nonlin3_eq_of_state_drhodp():
return 0.0
from veros.core.operators import numpy as npx
from veros import veros_kernel, KernelOutput
from veros.variables import allocate
from veros.core import utilities
from veros.core.operators import update, update_add, update_multiply, at
@veros_kernel
def compute_dissipation(state, int_drhodX, flux_east, flux_north):
vs = state.variables
settings = state.settings
diss = allocate(state.dimensions, ("xt", "yt", "zt"))
diss = update(
diss,
at[1:-1, 1:-1, :],
0.5
* settings.grav
/ settings.rho_0
* (
(int_drhodX[2:, 1:-1, :] - int_drhodX[1:-1, 1:-1, :]) * flux_east[1:-1, 1:-1, :]
+ (int_drhodX[1:-1, 1:-1, :] - int_drhodX[:-2, 1:-1, :]) * flux_east[:-2, 1:-1, :]
)
/ (vs.dxt[1:-1, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, 1:-1, npx.newaxis])
+ 0.5
* settings.grav
/ settings.rho_0
* (
(int_drhodX[1:-1, 2:, :] - int_drhodX[1:-1, 1:-1, :]) * flux_north[1:-1, 1:-1, :]
+ (int_drhodX[1:-1, 1:-1, :] - int_drhodX[1:-1, :-2, :]) * flux_north[1:-1, :-2, :]
)
/ (vs.dyt[npx.newaxis, 1:-1, npx.newaxis] * vs.cost[npx.newaxis, 1:-1, npx.newaxis]),
)
return diss
@veros_kernel
def dissipation_on_wgrid(state, diss, ks):
vs = state.variables
settings = state.settings
land_mask, water_mask, edge_mask = utilities.create_water_masks(ks, settings.nz)
water_mask = npx.logical_and(water_mask, npx.logical_not(edge_mask))
dzw_pad = utilities.pad_z_edges(vs.dzw)
diss_w = allocate(state.dimensions, ("xt", "yt", "zt"))
diss_w = update(
diss_w,
at[:, :, :-1],
(
0.5 * (diss[:, :, :-1] + diss[:, :, 1:])
+ 0.5 * (diss[:, :, :-1] * dzw_pad[npx.newaxis, npx.newaxis, :-3] / vs.dzw[npx.newaxis, npx.newaxis, :-1])
)
* edge_mask[:, :, :-1]
+ 0.5 * (diss[:, :, :-1] + diss[:, :, 1:]) * water_mask[:, :, :-1],
)
diss_w = update(diss_w, at[:, :, -1], diss[:, :, -1] * land_mask)
return diss_w
@veros_kernel
def tempsalt_biharmonic(state):
"""
biharmonic mixing of temp and salinity,
dissipation of dyn. Enthalpy is stored
"""
vs = state.variables
settings = state.settings
fxa = npx.sqrt(abs(settings.K_hbi))
# update temp
dtemp, flux_east, flux_north = biharmonic_diffusion(state, vs.temp[:, :, :, vs.tau], fxa)
vs.dtemp_hmix = update(vs.dtemp_hmix, at[1:, 1:, :], dtemp[1:, 1:, :])
vs.temp = update_add(vs.temp, at[:, :, :, vs.taup1], settings.dt_tracer * vs.dtemp_hmix * vs.maskT)
if settings.enable_conserve_energy:
diss = compute_dissipation(state, vs.int_drhodT[..., vs.tau], flux_east, flux_north)
vs.P_diss_hmix = dissipation_on_wgrid(state, diss, vs.kbot)
# update salt
dsalt, flux_east, flux_north = biharmonic_diffusion(state, vs.salt[:, :, :, vs.tau], fxa)
vs.dsalt_hmix = update(vs.dsalt_hmix, at[1:, 1:, :], dsalt[1:, 1:, :])
vs.salt = update_add(vs.salt, at[:, :, :, vs.taup1], settings.dt_tracer * vs.dsalt_hmix * vs.maskT)
if settings.enable_conserve_energy:
diss = compute_dissipation(state, vs.int_drhodS[..., vs.tau], flux_east, flux_north)
vs.P_diss_hmix = vs.P_diss_hmix + dissipation_on_wgrid(state, diss, vs.kbot)
return KernelOutput(
temp=vs.temp, salt=vs.salt, dtemp_hmix=vs.dtemp_hmix, dsalt_hmix=vs.dsalt_hmix, P_diss_hmix=vs.P_diss_hmix
)
@veros_kernel
def tempsalt_diffusion(state):
"""
Diffusion of temp and salinity,
dissipation of dyn. Enthalpy is stored
"""
vs = state.variables
settings = state.settings
# horizontal diffusion of temperature
dtemp, flux_east, flux_north = horizontal_diffusion(state, vs.temp[:, :, :, vs.tau], settings.K_h)
vs.dtemp_hmix = update(vs.dtemp_hmix, at[1:, 1:, :], dtemp[1:, 1:, :])
vs.temp = update_add(vs.temp, at[:, :, :, vs.taup1], settings.dt_tracer * vs.dtemp_hmix * vs.maskT)
if settings.enable_conserve_energy:
diss = compute_dissipation(state, vs.int_drhodT[..., vs.tau], flux_east, flux_north)
vs.P_diss_hmix = dissipation_on_wgrid(state, diss, vs.kbot)
# horizontal diffusion of salinity
dsalt, flux_east, flux_north = horizontal_diffusion(state, vs.salt[:, :, :, vs.tau], settings.K_h)
vs.dsalt_hmix = update(vs.dsalt_hmix, at[1:, 1:, :], dsalt[1:, 1:, :])
vs.salt = update_add(vs.salt, at[:, :, :, vs.taup1], settings.dt_tracer * vs.dsalt_hmix * vs.maskT)
if settings.enable_conserve_energy:
diss = compute_dissipation(state, vs.int_drhodS[..., vs.tau], flux_east, flux_north)
vs.P_diss_hmix = vs.P_diss_hmix + dissipation_on_wgrid(state, diss, vs.kbot)
return KernelOutput(
temp=vs.temp, salt=vs.salt, dtemp_hmix=vs.dtemp_hmix, dsalt_hmix=vs.dsalt_hmix, P_diss_hmix=vs.P_diss_hmix
)
@veros_kernel
def tempsalt_sources(state):
"""
Sources of temp and salinity,
effect on dyn. Enthalpy is stored
"""
vs = state.variables
settings = state.settings
vs.temp = update_add(vs.temp, at[:, :, :, vs.taup1], settings.dt_tracer * vs.temp_source * vs.maskT)
vs.salt = update_add(vs.salt, at[:, :, :, vs.taup1], settings.dt_tracer * vs.salt_source * vs.maskT)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yt", "zt"))
diss = update(
diss,
at[1:-1, 1:-1, :],
-settings.grav
/ settings.rho_0
* vs.maskT[1:-1, 1:-1, :]
* (
vs.int_drhodT[1:-1, 1:-1, :, vs.tau] * vs.temp_source[1:-1, 1:-1]
+ vs.int_drhodS[1:-1, 1:-1, :, vs.tau] * vs.salt_source[1:-1, 1:-1]
),
)
vs.P_diss_sources = dissipation_on_wgrid(state, diss, vs.kbot)
return KernelOutput(temp=vs.temp, salt=vs.salt, P_diss_sources=vs.P_diss_sources)
@veros_kernel
def biharmonic_diffusion(state, tr, diffusivity):
"""
Biharmonic mixing of tracer tr
"""
vs = state.variables
settings = state.settings
del2 = allocate(state.dimensions, ("xt", "yt", "zt"))
dtr = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_east = update(
flux_east,
at[:-1, :, :],
-diffusivity
* (tr[1:, :, :] - tr[:-1, :, :])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskU[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
-diffusivity
* (tr[:, 1:, :] - tr[:, :-1, :])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
del2 = update(
del2,
at[1:, 1:, :],
vs.maskT[1:, 1:, :]
* (flux_east[1:, 1:, :] - flux_east[:-1, 1:, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (flux_north[1:, 1:, :] - flux_north[1:, :-1, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis]),
)
del2 = utilities.enforce_boundaries(del2, settings.enable_cyclic_x)
flux_east = update(
flux_east,
at[:-1, :, :],
diffusivity
* (del2[1:, :, :] - del2[:-1, :, :])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskU[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
diffusivity
* (del2[:, 1:, :] - del2[:, :-1, :])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
dtr = update(
dtr,
at[1:, 1:, :],
(flux_east[1:, 1:, :] - flux_east[:-1, 1:, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (flux_north[1:, 1:, :] - flux_north[1:, :-1, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis]),
)
dtr = dtr * vs.maskT
return dtr, flux_east, flux_north
@veros_kernel
def horizontal_diffusion(state, tr, diffusivity):
"""
Diffusion of tracer tr
"""
vs = state.variables
settings = state.settings
dtr_hmix = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
# horizontal diffusion of tracer
flux_east = update(
flux_east,
at[:-1, :, :],
diffusivity
* (tr[1:, :, :] - tr[:-1, :, :])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskU[:-1, :, :],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(
flux_north,
at[:, :-1, :],
diffusivity
* (tr[:, 1:, :] - tr[:, :-1, :])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_north = update(flux_north, at[:, -1, :], 0.0)
if settings.enable_hor_friction_cos_scaling:
flux_east = update_multiply(
flux_east, at[...], vs.cost[npx.newaxis, :, npx.newaxis] ** settings.hor_friction_cosPower
)
flux_north = update_multiply(
flux_north, at[...], vs.cosu[npx.newaxis, :, npx.newaxis] ** settings.hor_friction_cosPower
)
dtr_hmix = update(
dtr_hmix,
at[1:, 1:, :],
(
(flux_east[1:, 1:, :] - flux_east[:-1, 1:, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (flux_north[1:, 1:, :] - flux_north[1:, :-1, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis])
)
* vs.maskT[1:, 1:, :],
)
return dtr_hmix, flux_east, flux_north
from veros.core.operators import numpy as npx
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import utilities, advection
from veros.core.operators import update, update_add, at
@veros_routine
def set_eke_diffusivities(state):
vs = state.variables
settings = state.settings
eke_diff_out = set_eke_diffusivities_kernel(state)
vs.update(eke_diff_out)
if settings.enable_TEM_friction:
kappa_gm_out = update_kappa_gm(state)
vs.update(kappa_gm_out)
@veros_kernel
def update_kappa_gm(state):
vs = state.variables
kappa_gm = (
vs.K_gm
* npx.minimum(0.01, vs.coriolis_t[..., npx.newaxis] ** 2 / npx.maximum(1e-9, vs.Nsqr[..., vs.tau]))
* vs.maskW
)
return KernelOutput(kappa_gm=kappa_gm)
@veros_kernel
def set_eke_diffusivities_kernel(state):
"""
set skew diffusivity K_gm and isopycnal diffusivity K_iso
set also vertical viscosity if TEM formalism is chosen
"""
vs = state.variables
settings = state.settings
if settings.enable_eke:
"""
calculate Rossby radius as minimum of mid-latitude and equatorial R. rad.
"""
C_rossby = npx.sum(
npx.sqrt(npx.maximum(0.0, vs.Nsqr[:, :, :, vs.tau]))
* vs.dzw[npx.newaxis, npx.newaxis, :]
* vs.maskW[:, :, :]
/ settings.pi,
axis=2,
)
vs.L_rossby = npx.minimum(
C_rossby / npx.maximum(npx.abs(vs.coriolis_t), 1e-16), npx.sqrt(C_rossby / npx.maximum(2 * vs.beta, 1e-16))
)
"""
calculate vertical viscosity and skew diffusivity
"""
vs.sqrteke = npx.sqrt(npx.maximum(0.0, vs.eke[:, :, :, vs.tau]))
vs.L_rhines = npx.sqrt(vs.sqrteke / npx.maximum(vs.beta[..., npx.newaxis], 1e-16))
vs.eke_len = npx.maximum(
settings.eke_lmin,
npx.minimum(settings.eke_cross * vs.L_rossby[..., npx.newaxis], settings.eke_crhin * vs.L_rhines),
)
vs.K_gm = npx.minimum(settings.eke_k_max, settings.eke_c_k * vs.eke_len * vs.sqrteke)
else:
"""
use fixed GM diffusivity
"""
vs.K_gm = update(vs.K_gm, at[...], settings.K_gm_0)
if settings.enable_eke and settings.enable_eke_isopycnal_diffusion:
vs.K_iso = update(vs.K_iso, at[...], vs.K_gm)
else:
vs.K_iso = update(vs.K_iso, at[...], settings.K_iso_0) # always constant
if not settings.enable_eke:
return KernelOutput(K_gm=vs.K_gm, K_iso=vs.K_iso)
return KernelOutput(
L_rossby=vs.L_rossby, L_rhines=vs.L_rhines, eke_len=vs.eke_len, sqrteke=vs.sqrteke, K_gm=vs.K_gm, K_iso=vs.K_iso
)
@veros_routine
def integrate_eke(state):
vs = state.variables
vs.update(integrate_eke_kernel(state))
@veros_kernel
def integrate_eke_kernel(state):
"""
integrate EKE equation on W grid
"""
vs = state.variables
settings = state.settings
c_int = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_top = allocate(state.dimensions, ("xt", "yt", "zt"))
"""
forcing by dissipation by lateral friction and GM using TRM formalism or skew diffusion
"""
forc = vs.K_diss_gm + vs.K_diss_h - vs.P_diss_skew
"""
store transfer due to isopycnal and horizontal mixing from dyn. enthalpy
by non-linear eq.of state either to EKE or to heat
"""
if not settings.enable_store_cabbeling_heat:
forc = forc - vs.P_diss_hmix - vs.P_diss_iso
conditional_outputs = {}
"""
dissipation by local interior loss of balance with constant coefficient
"""
c_int = settings.eke_c_eps * vs.sqrteke / vs.eke_len * vs.maskW
"""
vertical diffusion of EKE,forcing and dissipation
"""
_, water_mask, edge_mask = utilities.create_water_masks(vs.kbot[2:-2, 2:-2], settings.nz)
delta, a_tri, b_tri, c_tri, d_tri = (
allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2, :] for _ in range(5)
)
delta = update(
delta,
at[:, :, :-1],
settings.dt_tracer
/ vs.dzt[npx.newaxis, npx.newaxis, 1:]
* 0.5
* (vs.kappaM[2:-2, 2:-2, :-1] + vs.kappaM[2:-2, 2:-2, 1:])
* settings.alpha_eke,
)
a_tri = update(a_tri, at[:, :, 1:-1], -delta[:, :, :-2] / vs.dzw[1:-1])
a_tri = update(a_tri, at[:, :, -1], -delta[:, :, -2] / (0.5 * vs.dzw[-1]))
b_tri = update(
b_tri,
at[:, :, 1:-1],
1 + (delta[:, :, 1:-1] + delta[:, :, :-2]) / vs.dzw[1:-1] + settings.dt_tracer * c_int[2:-2, 2:-2, 1:-1],
)
b_tri = update(
b_tri, at[:, :, -1], 1 + delta[:, :, -2] / (0.5 * vs.dzw[-1]) + settings.dt_tracer * c_int[2:-2, 2:-2, -1]
)
b_tri_edge = 1 + delta / vs.dzw[npx.newaxis, npx.newaxis, :] + settings.dt_tracer * c_int[2:-2, 2:-2, :]
c_tri = update(c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzw[npx.newaxis, npx.newaxis, :-1])
d_tri = update(d_tri, at[:, :, :], vs.eke[2:-2, 2:-2, :, vs.tau] + settings.dt_tracer * forc[2:-2, 2:-2, :])
sol = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
vs.eke = update(vs.eke, at[2:-2, 2:-2, :, vs.taup1], npx.where(water_mask, sol, vs.eke[2:-2, 2:-2, :, vs.taup1]))
"""
store eke dissipation
"""
vs.eke_diss_iw = c_int * vs.eke[:, :, :, vs.taup1]
vs.eke_diss_tke = update(vs.eke_diss_tke, at[...], 0.0)
"""
add tendency due to lateral diffusion
"""
flux_east = update(
flux_east,
at[:-1, :, :],
0.5
* npx.maximum(500.0, vs.K_gm[:-1, :, :] + vs.K_gm[1:, :, :])
* (vs.eke[1:, :, :, vs.tau] - vs.eke[:-1, :, :, vs.tau])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskU[:-1, :, :],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(
flux_north,
at[:, :-1, :],
0.5
* npx.maximum(500.0, vs.K_gm[:, :-1, :] + vs.K_gm[:, 1:, :])
* (vs.eke[:, 1:, :, vs.tau] - vs.eke[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_north = update(flux_north, at[:, -1, :], 0.0)
vs.eke = update_add(
vs.eke,
at[2:-2, 2:-2, :, vs.taup1],
settings.dt_tracer
* vs.maskW[2:-2, 2:-2, :]
* (
(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
"""
add tendency due to advection
"""
if settings.enable_eke_superbee_advection:
flux_east, flux_north, flux_top = advection.adv_flux_superbee_wgrid(state, vs.eke[:, :, :, vs.tau])
if settings.enable_eke_upwind_advection:
flux_east, flux_north, flux_top = advection.adv_flux_upwind_wgrid(state, vs.eke[:, :, :, vs.tau])
if settings.enable_eke_superbee_advection or settings.enable_eke_upwind_advection:
vs.deke = update(
vs.deke,
at[2:-2, 2:-2, :, vs.tau],
vs.maskW[2:-2, 2:-2, :]
* (
-(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
- (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
vs.deke = update_add(vs.deke, at[:, :, 0, vs.tau], -flux_top[:, :, 0] / vs.dzw[0])
vs.deke = update_add(
vs.deke,
at[:, :, 1:-1, vs.tau],
-(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / vs.dzw[npx.newaxis, npx.newaxis, 1:-1],
)
vs.deke = update_add(
vs.deke, at[:, :, -1, vs.tau], -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * vs.dzw[-1])
)
"""
Adam Bashforth time stepping
"""
vs.eke = update_add(
vs.eke,
at[:, :, :, vs.taup1],
settings.dt_tracer
* (
(1.5 + settings.AB_eps) * vs.deke[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.deke[:, :, :, vs.taum1]
),
)
conditional_outputs.update(deke=vs.deke)
return KernelOutput(eke=vs.eke, eke_diss_iw=vs.eke_diss_iw, eke_diss_tke=vs.eke_diss_tke, **conditional_outputs)
from veros.core.external.streamfunction_init import streamfunction_init # noqa: F401
from veros.core.external.solve_stream import solve_streamfunction # noqa: F401
from veros.core.external.solve_pressure import solve_pressure # noqa: F401
import scipy.ndimage
from veros import veros_routine, logger
from veros.core import utilities
from veros.core.operators import numpy as npx
# fall back to vanilla NumPy for some operations
import numpy as onp
def _compute_isleperim(kmt, enable_cyclic_x):
# TODO: remove this check after jax#6907 has landed
if enable_cyclic_x:
kmt = utilities.enforce_boundaries(kmt, enable_cyclic_x)
kmt = onp.asarray(kmt)
structure = onp.ones((3, 3)) # merge diagonally connected land masses
# find all land masses
labelled, _ = scipy.ndimage.label(kmt == 0, structure=structure)
# find and set perimeter
land_masses = labelled > 0
inner = scipy.ndimage.binary_dilation(land_masses, structure=structure)
perimeter = onp.logical_xor(inner, land_masses)
labelled[perimeter] = -1
# match wrapping periodic land masses
if enable_cyclic_x:
west_slice = onp.array(labelled[2])
east_slice = onp.array(labelled[-2])
for west_label in onp.unique(west_slice[west_slice > 0]):
east_labels = onp.unique(east_slice[west_slice == west_label])
east_labels = east_labels[~onp.isin(east_labels, [west_label, -1])]
if not east_labels.size:
# already labelled correctly
continue
assert len(onp.unique(east_labels)) == 1, (west_label, east_labels)
labelled[labelled == east_labels[0]] = west_label
# TODO: remove this check after jax#6907 has landed
if enable_cyclic_x:
labelled = utilities.enforce_boundaries(labelled, enable_cyclic_x)
labelled = onp.asarray(labelled)
# label landmasses in a way that is consistent with pyom
labels = onp.unique(labelled[labelled > 0])
label_idx = {}
for label in labels:
# find index of first island cell, scanning west to east, north to south
label_idx[label] = onp.argmax(labelled[:, ::-1].T == label)
sorted_labels = list(sorted(labels, key=lambda i: label_idx[i]))
# ensure labels are numbered consecutively
relabelled = onp.array(labelled)
for new_label, label in enumerate(sorted_labels, 1):
if label == new_label:
continue
relabelled[labelled == label] = new_label
return npx.asarray(relabelled)
@veros_routine(dist_safe=False, local_variables=("kbot", "land_map"))
def isleperim(state):
vs = state.variables
settings = state.settings
logger.debug(" Determining number of land masses")
vs.land_map = _compute_isleperim(vs.kbot, settings.enable_cyclic_x)
if vs.land_map.size < 10_000:
logger.debug(_ascii_map(vs.land_map))
def _ascii_map(boundary_map):
def _get_char(c):
if c == 0:
return "."
if c < 0:
return "#"
return str(c % 10)
boundary_map = onp.array(boundary_map)
nx, ny = boundary_map.shape
map_string = ""
linewidth = 100
iremain = nx
istart = 0
map_string += "\n"
map_string += " " * (5 + min(linewidth, nx) // 2 - 13) + "Land mass and perimeter"
map_string += "\n"
for _ in range(1, nx // linewidth + 2):
iline = min(iremain, linewidth)
iremain = iremain - iline
if iline > 0:
map_string += "\n"
map_string += "".join([f"{istart + i + 1 - 2:5d}" for i in range(1, iline + 1, 5)])
map_string += "\n"
for j in range(ny - 1, -1, -1):
map_string += f"{j:3d} "
map_string += "".join([_get_char(boundary_map[istart + i - 2, j]) for i in range(2, iline + 2)])
map_string += "\n"
map_string += "".join([f"{istart + i + 1 - 2:5d}" for i in range(1, iline + 1, 5)])
map_string += "\n"
istart = istart + iline
map_string += "\n"
return map_string
from veros.core.operators import numpy as npx
from veros import veros_kernel, runtime_state
from veros.distributed import global_sum
from veros.core.operators import update, at, for_loop
@veros_kernel(static_args=("kind"))
def line_integrals(state, uloc, vloc, kind="same"):
"""
calculate line integrals along all islands
Arguments:
kind: 'same' calculates only line integral contributions of an island with itself,
while 'full' calculates all possible pairings between all islands.
"""
vs = state.variables
nisle = state.dimensions["isle"]
ipx, ipy = runtime_state.proc_idx
if ipx == 0:
i = slice(1, -2)
ip1 = slice(2, -1)
else:
i = slice(2, -2)
ip1 = slice(3, -1)
if ipy == 0:
j = slice(1, -2)
jp1 = slice(2, -1)
else:
j = slice(2, -2)
jp1 = slice(3, -1)
east = (
vloc[i, j, :] * vs.dyu[npx.newaxis, j, npx.newaxis]
+ uloc[i, jp1, :] * vs.dxu[i, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, jp1, npx.newaxis]
)
west = (
-vloc[ip1, j, :] * vs.dyu[npx.newaxis, j, npx.newaxis]
- uloc[i, j, :] * vs.dxu[i, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, j, npx.newaxis]
)
north = (
vloc[i, j, :] * vs.dyu[npx.newaxis, j, npx.newaxis]
- uloc[i, j, :] * vs.dxu[i, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, j, npx.newaxis]
)
south = (
-vloc[ip1, j, :] * vs.dyu[npx.newaxis, j, npx.newaxis]
+ uloc[i, jp1, :] * vs.dxu[i, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, jp1, npx.newaxis]
)
if kind == "same":
east = npx.sum(east * vs.line_dir_east_mask[i, j], axis=(0, 1))
west = npx.sum(west * vs.line_dir_west_mask[i, j], axis=(0, 1))
north = npx.sum(north * vs.line_dir_north_mask[i, j], axis=(0, 1))
south = npx.sum(south * vs.line_dir_south_mask[i, j], axis=(0, 1))
return global_sum(east + west + north + south)
elif kind == "full":
isle_int = npx.empty((nisle, nisle))
def loop_body(isle, isle_int):
east_isle = npx.sum(
east[..., isle, npx.newaxis] * vs.line_dir_east_mask[i, j],
axis=(0, 1),
)
west_isle = npx.sum(
west[..., isle, npx.newaxis] * vs.line_dir_west_mask[i, j],
axis=(0, 1),
)
north_isle = npx.sum(
north[..., isle, npx.newaxis] * vs.line_dir_north_mask[i, j],
axis=(0, 1),
)
south_isle = npx.sum(
south[..., isle, npx.newaxis] * vs.line_dir_south_mask[i, j],
axis=(0, 1),
)
isle_int = update(isle_int, at[:, isle], east_isle + west_isle + north_isle + south_isle)
return isle_int
isle_int = for_loop(0, nisle, loop_body, isle_int)
return global_sum(isle_int)
else:
raise ValueError('"kind" argument must be "same" or "full"')
from veros.core.operators import update, at, numpy as npx
from veros.variables import allocate
def assemble_poisson_matrix(state):
if state.settings.enable_streamfunction:
return assemble_streamfunction_matrix(state)
else:
return assemble_pressure_matrix(state)
def assemble_pressure_matrix(state):
main_diag = allocate(state.dimensions, ("xu", "yu"), fill=1)
east_diag, west_diag, north_diag, south_diag = (allocate(state.dimensions, ("xu", "yu")) for _ in range(4))
vs = state.variables
settings = state.settings
maskM = vs.maskT[:, :, -1]
mp_i = maskM[2:-2, 2:-2] * maskM[3:-1, 2:-2]
mm_i = maskM[2:-2, 2:-2] * maskM[1:-3, 2:-2]
mp_j = maskM[2:-2, 2:-2] * maskM[2:-2, 3:-1]
mm_j = maskM[2:-2, 2:-2] * maskM[2:-2, 1:-3]
main_diag = update(
main_diag,
at[2:-2, 2:-2],
-1
* mp_i
* vs.hu[2:-2, 2:-2]
/ vs.dxu[2:-2, npx.newaxis]
/ vs.dxt[2:-2, npx.newaxis]
/ vs.cost[npx.newaxis, 2:-2] ** 2
- 1
* mm_i
* vs.hu[1:-3, 2:-2]
/ vs.dxu[1:-3, npx.newaxis]
/ vs.dxt[2:-2, npx.newaxis]
/ vs.cost[npx.newaxis, 2:-2] ** 2
- 1
* mp_j
* vs.hv[2:-2, 2:-2]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cosu[npx.newaxis, 2:-2]
/ vs.cost[npx.newaxis, 2:-2]
- 1
* mm_j
* vs.hv[2:-2, 1:-3]
/ vs.dyu[npx.newaxis, 1:-3]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cosu[npx.newaxis, 1:-3]
/ vs.cost[npx.newaxis, 2:-2]
# free surface
- 1.0 / (settings.grav * settings.dt_mom * settings.dt_tracer) * maskM[2:-2, 2:-2],
)
east_diag = update(
east_diag,
at[2:-2, 2:-2],
mp_i
* vs.hu[2:-2, 2:-2]
/ vs.dxu[2:-2, npx.newaxis]
/ vs.dxt[2:-2, npx.newaxis]
/ vs.cost[npx.newaxis, 2:-2] ** 2,
)
west_diag = update(
west_diag,
at[2:-2, 2:-2],
mm_i
* vs.hu[1:-3, 2:-2]
/ vs.dxu[1:-3, npx.newaxis]
/ vs.dxt[2:-2, npx.newaxis]
/ vs.cost[npx.newaxis, 2:-2] ** 2,
)
north_diag = update(
north_diag,
at[2:-2, 2:-2],
mp_j
* vs.hv[2:-2, 2:-2]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cosu[npx.newaxis, 2:-2]
/ vs.cost[npx.newaxis, 2:-2],
)
south_diag = update(
south_diag,
at[2:-2, 2:-2],
mm_j
* vs.hv[2:-2, 1:-3]
/ vs.dyu[npx.newaxis, 1:-3]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cosu[npx.newaxis, 1:-3]
/ vs.cost[npx.newaxis, 2:-2],
)
main_diag = main_diag * maskM
main_diag = npx.where(npx.abs(main_diag) == 0.0, 1, main_diag)
offsets = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
diags = [main_diag, east_diag, west_diag, north_diag, south_diag]
return diags, offsets, maskM
def assemble_streamfunction_matrix(state):
vs = state.variables
# assemble diagonals
main_diag = allocate(state.dimensions, ("xu", "yu"), fill=1)
east_diag, west_diag, north_diag, south_diag = (allocate(state.dimensions, ("xu", "yu")) for _ in range(4))
main_diag = update(
main_diag,
at[2:-2, 2:-2],
-vs.hvr[3:-1, 2:-2] / vs.dxu[2:-2, npx.newaxis] / vs.dxt[3:-1, npx.newaxis] / vs.cosu[npx.newaxis, 2:-2] ** 2
- vs.hvr[2:-2, 2:-2] / vs.dxu[2:-2, npx.newaxis] / vs.dxt[2:-2, npx.newaxis] / vs.cosu[npx.newaxis, 2:-2] ** 2
- vs.hur[2:-2, 2:-2]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cost[npx.newaxis, 2:-2]
/ vs.cosu[npx.newaxis, 2:-2]
- vs.hur[2:-2, 3:-1]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 3:-1]
* vs.cost[npx.newaxis, 3:-1]
/ vs.cosu[npx.newaxis, 2:-2],
)
east_diag = update(
east_diag,
at[2:-2, 2:-2],
vs.hvr[3:-1, 2:-2] / vs.dxu[2:-2, npx.newaxis] / vs.dxt[3:-1, npx.newaxis] / vs.cosu[npx.newaxis, 2:-2] ** 2,
)
west_diag = update(
west_diag,
at[2:-2, 2:-2],
vs.hvr[2:-2, 2:-2] / vs.dxu[2:-2, npx.newaxis] / vs.dxt[2:-2, npx.newaxis] / vs.cosu[npx.newaxis, 2:-2] ** 2,
)
north_diag = update(
north_diag,
at[2:-2, 2:-2],
vs.hur[2:-2, 3:-1]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 3:-1]
* vs.cost[npx.newaxis, 3:-1]
/ vs.cosu[npx.newaxis, 2:-2],
)
south_diag = update(
south_diag,
at[2:-2, 2:-2],
vs.hur[2:-2, 2:-2]
/ vs.dyu[npx.newaxis, 2:-2]
/ vs.dyt[npx.newaxis, 2:-2]
* vs.cost[npx.newaxis, 2:-2]
/ vs.cosu[npx.newaxis, 2:-2],
)
main_diag = main_diag * vs.isle_boundary_mask
main_diag = npx.where(main_diag == 0.0, 1.0, main_diag)
offsets = [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
diags = [
main_diag,
east_diag * vs.isle_boundary_mask,
west_diag * vs.isle_boundary_mask,
north_diag * vs.isle_boundary_mask,
south_diag * vs.isle_boundary_mask,
]
return diags, offsets, vs.isle_boundary_mask
"""
solve two dimensional Possion equation
A * dpsi = forc, where A = nabla_h^2
with Neumann boundary conditions
used for surface pressure or free surface
method same as pressure method in MITgcm
"""
from veros import veros_routine
from veros.routines import veros_kernel
from veros.state import KernelOutput
from veros.variables import allocate
from veros.core import utilities as mainutils
from veros.core.operators import update, update_add, at, for_loop
from veros.core.operators import numpy as npx
from veros.core.external.solvers import get_linear_solver
@veros_routine
def solve_pressure(state):
vs = state.variables
state_update, forc = prepare_forcing(state)
vs.update(state_update)
linear_solver = get_linear_solver(state)
linear_sol = linear_solver.solve(state, forc, vs.psi[..., vs.taup1])
linear_sol = mainutils.enforce_boundaries(linear_sol, state.settings.enable_cyclic_x)
if vs.itt == 0:
vs.psi = update(vs.psi, at[...], linear_sol[..., npx.newaxis])
else:
vs.psi = update(vs.psi, at[..., vs.taup1], linear_sol)
vs.update(barotropic_velocity_update(state))
@veros_kernel
def prepare_forcing(state):
vs = state.variables
settings = state.settings
# hydrostatic pressure
vs.p_hydro = update(
vs.p_hydro,
at[:, :, -1],
0.5 * vs.rho[:, :, -1, vs.tau] * settings.grav / settings.rho_0 * vs.dzw[-1] * vs.maskT[:, :, -1],
)
def compute_p_hydro(k_inv, p_hydro):
k = settings.nz - k_inv - 1
p_hydro = update(
p_hydro,
at[..., k],
vs.maskT[:, :, k]
* (
p_hydro[:, :, k + 1]
+ 0.5
* vs.dzw[k]
* settings.grav
/ settings.rho_0
* (vs.rho[:, :, k + 1, vs.tau] + vs.rho[:, :, k, vs.tau])
),
)
return p_hydro
vs.p_hydro = for_loop(1, settings.nz, compute_p_hydro, vs.p_hydro)
# add hydrostatic pressure gradient
vs.du = update_add(
vs.du,
at[2:-2, 2:-2, :, vs.tau],
-(vs.p_hydro[3:-1, 2:-2, :] - vs.p_hydro[2:-2, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxu[2:-2, npx.newaxis, npx.newaxis])
* vs.maskU[2:-2, 2:-2, :],
)
vs.dv = update_add(
vs.dv,
at[2:-2, 2:-2, :, vs.tau],
-(vs.p_hydro[2:-2, 3:-1, :] - vs.p_hydro[2:-2, 2:-2, :])
/ vs.dyu[npx.newaxis, 2:-2, npx.newaxis]
* vs.maskV[2:-2, 2:-2, :],
)
# Integrate forward in time
vs.u = update(
vs.u,
at[:, :, :, vs.taup1],
vs.u[:, :, :, vs.tau]
+ settings.dt_mom
* (
vs.du_mix
+ (1.5 + settings.AB_eps) * vs.du[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.du[:, :, :, vs.taum1]
)
* vs.maskU,
)
vs.v = update(
vs.v,
at[:, :, :, vs.taup1],
vs.v[:, :, :, vs.tau]
+ settings.dt_mom
* (
vs.dv_mix
+ (1.5 + settings.AB_eps) * vs.dv[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.dv[:, :, :, vs.taum1]
)
* vs.maskV,
)
# forcing for surface pressure
uloc = allocate(state.dimensions, ("xt", "yt"))
vloc = allocate(state.dimensions, ("xt", "yt"))
uloc = update(
uloc,
at[2:-2, 2:-2],
npx.sum((vs.u[2:-2, 2:-2, :, vs.taup1]) * vs.maskU[2:-2, 2:-2, :] * vs.dzt, axis=(2,)) / settings.dt_mom,
)
vloc = update(
vloc,
at[2:-2, 2:-2],
npx.sum((vs.v[2:-2, 2:-2, :, vs.taup1]) * vs.maskV[2:-2, 2:-2, :] * vs.dzt, axis=(2,)) / settings.dt_mom,
)
uloc = mainutils.enforce_boundaries(uloc, settings.enable_cyclic_x)
vloc = mainutils.enforce_boundaries(vloc, settings.enable_cyclic_x)
forc = allocate(state.dimensions, ("xt", "yt"))
forc = update(
forc,
at[2:-2, 2:-2],
(uloc[2:-2, 2:-2] - uloc[1:-3, 2:-2]) / (vs.cost[2:-2] * vs.dxt[2:-2, npx.newaxis])
+ (vs.cosu[2:-2] * vloc[2:-2, 2:-2] - vs.cosu[1:-3] * vloc[2:-2, 1:-3]) / (vs.cost[2:-2] * vs.dyt[2:-2])
# free surface
- vs.psi[2:-2, 2:-2, vs.tau]
/ (settings.grav * settings.dt_mom * settings.dt_tracer)
* vs.maskT[2:-2, 2:-2, -1],
)
# first guess
vs.psi = update(vs.psi, at[:, :, vs.taup1], 2 * vs.psi[:, :, vs.tau] - vs.psi[:, :, vs.taum1])
return KernelOutput(du=vs.du, dv=vs.dv, u=vs.u, v=vs.v, psi=vs.psi, p_hydro=vs.p_hydro), forc
@veros_kernel
def barotropic_velocity_update(state):
"""
solve for surface pressure
"""
vs = state.variables
settings = state.settings
vs.u = update_add(
vs.u,
at[2:-2, 2:-2, :, vs.taup1],
-settings.dt_mom
* (vs.psi[3:-1, 2:-2, vs.taup1, npx.newaxis] - vs.psi[2:-2, 2:-2, vs.taup1, npx.newaxis])
/ (vs.dxu[2:-2, npx.newaxis, npx.newaxis] * vs.cost[2:-2, npx.newaxis])
* vs.maskU[2:-2, 2:-2, :],
)
vs.v = update_add(
vs.v,
at[2:-2, 2:-2, :, vs.taup1],
-settings.dt_mom
* (vs.psi[2:-2, 3:-1, vs.taup1, npx.newaxis] - vs.psi[2:-2, 2:-2, vs.taup1, npx.newaxis])
/ vs.dyu[npx.newaxis, 2:-2, npx.newaxis]
* vs.maskV[2:-2, 2:-2, :],
)
vs.ssh = vs.psi[..., vs.tau] / settings.grav
return KernelOutput(u=vs.u, v=vs.v, ssh=vs.ssh)
"""
solve two dimensional Possion equation
A * dpsi = forc, where A = nabla_h^2
with Dirichlet boundary conditions
used for streamfunction
"""
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import utilities as mainutils
from veros.core.operators import update, update_add, at, for_loop
from veros.core.operators import numpy as npx
from veros.core.external import line_integrals
from veros.core.external.solvers import get_linear_solver
@veros_routine
def solve_streamfunction(state):
vs = state.variables
state_update, (forc, uloc, vloc) = prepare_forcing(state)
vs.update(state_update)
linear_solver = get_linear_solver(state)
linear_sol = linear_solver.solve(state, forc, vs.dpsi[..., vs.taup1])
vs.dpsi = update(vs.dpsi, at[..., vs.taup1], linear_sol)
vs.update(barotropic_velocity_update(state, uloc=uloc, vloc=vloc))
@veros_kernel
def prepare_forcing(state):
vs = state.variables
settings = state.settings
# hydrostatic pressure
vs.p_hydro = update(
vs.p_hydro,
at[:, :, -1],
0.5 * vs.rho[:, :, -1, vs.tau] * settings.grav / settings.rho_0 * vs.dzw[-1] * vs.maskT[:, :, -1],
)
def compute_p_hydro(k_inv, p_hydro):
k = settings.nz - k_inv - 1
p_hydro = update(
p_hydro,
at[..., k],
vs.maskT[:, :, k]
* (
p_hydro[:, :, k + 1]
+ 0.5
* vs.dzw[k]
* settings.grav
/ settings.rho_0
* (vs.rho[:, :, k + 1, vs.tau] + vs.rho[:, :, k, vs.tau])
),
)
return p_hydro
vs.p_hydro = for_loop(1, settings.nz, compute_p_hydro, vs.p_hydro)
# add hydrostatic pressure gradient
vs.du = update_add(
vs.du,
at[2:-2, 2:-2, :, vs.tau],
-(vs.p_hydro[3:-1, 2:-2, :] - vs.p_hydro[2:-2, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxu[2:-2, npx.newaxis, npx.newaxis])
* vs.maskU[2:-2, 2:-2, :],
)
vs.dv = update_add(
vs.dv,
at[2:-2, 2:-2, :, vs.tau],
-(vs.p_hydro[2:-2, 3:-1, :] - vs.p_hydro[2:-2, 2:-2, :])
/ vs.dyu[npx.newaxis, 2:-2, npx.newaxis]
* vs.maskV[2:-2, 2:-2, :],
)
# forcing for barotropic streamfunction
uloc = npx.sum((vs.du[:, :, :, vs.tau] + vs.du_mix) * vs.maskU * vs.dzt, axis=(2,)) * vs.hur
vloc = npx.sum((vs.dv[:, :, :, vs.tau] + vs.dv_mix) * vs.maskV * vs.dzt, axis=(2,)) * vs.hvr
uloc = mainutils.enforce_boundaries(uloc, settings.enable_cyclic_x)
vloc = mainutils.enforce_boundaries(vloc, settings.enable_cyclic_x)
forc = allocate(state.dimensions, ("xt", "yt"))
forc = update(
forc,
at[2:-2, 2:-2],
(vloc[3:-1, 2:-2] - vloc[2:-2, 2:-2]) / (vs.cosu[2:-2] * vs.dxu[2:-2, npx.newaxis])
- (vs.cost[3:-1] * uloc[2:-2, 3:-1] - vs.cost[2:-2] * uloc[2:-2, 2:-2]) / (vs.cosu[2:-2] * vs.dyu[2:-2]),
)
# solve for interior streamfunction
vs.dpsi = update(vs.dpsi, at[:, :, vs.taup1], 2 * vs.dpsi[:, :, vs.tau] - vs.dpsi[:, :, vs.taum1])
return KernelOutput(du=vs.du, dv=vs.dv, dpsi=vs.dpsi, p_hydro=vs.p_hydro), (forc, uloc, vloc)
@veros_kernel
def barotropic_velocity_update(state, uloc, vloc):
"""
solve for barotropic streamfunction
"""
vs = state.variables
settings = state.settings
vs.dpsi = update(
vs.dpsi, at[:, :, vs.taup1], mainutils.enforce_boundaries(vs.dpsi[:, :, vs.taup1], settings.enable_cyclic_x)
)
line_forc = allocate(state.dimensions, ("isle",))
if state.dimensions["isle"] > 1:
# calculate island integrals of forcing, keep psi constant on island 1
line_forc = update(
line_forc,
at[1:],
line_integrals.line_integrals(state, uloc=uloc[..., npx.newaxis], vloc=vloc[..., npx.newaxis], kind="same")[
1:
],
)
# calculate island integrals of interior streamfunction
uloc = update(uloc, at[...], 0.0)
vloc = update(vloc, at[...], 0.0)
uloc = update(
uloc,
at[1:, 1:],
-1
* vs.maskU[1:, 1:, -1]
* (vs.dpsi[1:, 1:, vs.taup1] - vs.dpsi[1:, :-1, vs.taup1])
/ vs.dyt[npx.newaxis, 1:]
* vs.hur[1:, 1:],
)
vloc = update(
vloc,
at[1:, 1:],
vs.maskV[1:, 1:, -1]
* (vs.dpsi[1:, 1:, vs.taup1] - vs.dpsi[:-1, 1:, vs.taup1])
/ (vs.cosu[npx.newaxis, 1:] * vs.dxt[1:, npx.newaxis])
* vs.hvr[1:, 1:],
)
line_forc = update_add(
line_forc,
at[1:],
-line_integrals.line_integrals(
state, uloc=uloc[..., npx.newaxis], vloc=vloc[..., npx.newaxis], kind="same"
)[1:],
)
# solve for time dependent boundary values
vs.dpsin = update(vs.dpsin, at[1:, vs.tau], npx.linalg.solve(vs.line_psin[1:, 1:], line_forc[1:]))
# integrate barotropic and baroclinic velocity forward in time
vs.psi = update(
vs.psi,
at[:, :, vs.taup1],
vs.psi[:, :, vs.tau]
+ settings.dt_mom
* ((1.5 + settings.AB_eps) * vs.dpsi[:, :, vs.taup1] - (0.5 + settings.AB_eps) * vs.dpsi[:, :, vs.tau]),
)
vs.psi = update_add(
vs.psi,
at[:, :, vs.taup1],
settings.dt_mom
* npx.sum(
((1.5 + settings.AB_eps) * vs.dpsin[1:, vs.tau] - (0.5 + settings.AB_eps) * vs.dpsin[1:, vs.taum1])
* vs.psin[:, :, 1:],
axis=2,
),
)
vs.u = update(
vs.u,
at[:, :, :, vs.taup1],
vs.u[:, :, :, vs.tau]
+ settings.dt_mom
* (
vs.du_mix
+ (1.5 + settings.AB_eps) * vs.du[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.du[:, :, :, vs.taum1]
)
* vs.maskU,
)
vs.v = update(
vs.v,
at[:, :, :, vs.taup1],
vs.v[:, :, :, vs.tau]
+ settings.dt_mom
* (
vs.dv_mix
+ (1.5 + settings.AB_eps) * vs.dv[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.dv[:, :, :, vs.taum1]
)
* vs.maskV,
)
# subtract incorrect vertical mean from baroclinic velocity
uloc = npx.sum(vs.u[:, :, :, vs.taup1] * vs.maskU * vs.dzt, axis=2)
vloc = npx.sum(vs.v[:, :, :, vs.taup1] * vs.maskV * vs.dzt, axis=2)
vs.u = update_add(vs.u, at[:, :, :, vs.taup1], -uloc[:, :, npx.newaxis] * vs.maskU * vs.hur[:, :, npx.newaxis])
vs.v = update_add(vs.v, at[:, :, :, vs.taup1], -vloc[:, :, npx.newaxis] * vs.maskV * vs.hvr[:, :, npx.newaxis])
# add barotropic mode to baroclinic velocity
vs.u = update_add(
vs.u,
at[2:-2, 2:-2, :, vs.taup1],
-1
* vs.maskU[2:-2, 2:-2, :]
* (vs.psi[2:-2, 2:-2, vs.taup1, npx.newaxis] - vs.psi[2:-2, 1:-3, vs.taup1, npx.newaxis])
/ vs.dyt[npx.newaxis, 2:-2, npx.newaxis]
* vs.hur[2:-2, 2:-2, npx.newaxis],
)
vs.v = update_add(
vs.v,
at[2:-2, 2:-2, :, vs.taup1],
vs.maskV[2:-2, 2:-2, :]
* (vs.psi[2:-2, 2:-2, vs.taup1, npx.newaxis] - vs.psi[1:-3, 2:-2, vs.taup1, npx.newaxis])
/ (vs.cosu[2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
* vs.hvr[2:-2, 2:-2][:, :, npx.newaxis],
)
return KernelOutput(u=vs.u, v=vs.v, psi=vs.psi, dpsi=vs.dpsi, dpsin=vs.dpsin)
import functools
from veros import runtime_settings as rs, runtime_state as rst, logger
def memoize(func):
func.cache = {}
@functools.wraps(func)
def inner(*args):
if args not in func.cache:
func.cache[args] = func(*args)
return func.cache[args]
return inner
def _get_solver_class():
ls = rs.linear_solver
def _get_best_solver():
if rst.proc_num > 1:
try:
from veros.core.external.solvers.petsc_ import PETScSolver
except ImportError:
logger.warning("PETSc linear solver not available, falling back to SciPy")
else:
return PETScSolver
if rs.backend == "jax" and rs.device == "gpu" and rs.float_type == "float64":
from veros.core.external.solvers.scipy_jax import JAXSciPySolver
return JAXSciPySolver
from veros.core.external.solvers.scipy import SciPySolver
return SciPySolver
if ls == "best":
return _get_best_solver()
elif ls == "petsc":
from veros.core.external.solvers.petsc_ import PETScSolver
return PETScSolver
elif ls == "scipy":
from veros.core.external.solvers.scipy import SciPySolver
return SciPySolver
elif ls == "scipy_jax":
from veros.core.external.solvers.scipy_jax import JAXSciPySolver
return JAXSciPySolver
raise ValueError(f"unrecognized linear solver {ls}")
@memoize
def get_linear_solver(state):
logger.debug("Initializing linear solver")
SolverClass = _get_solver_class()
return SolverClass(state)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment