Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
from abc import abstractmethod, ABCMeta
class LinearSolver(metaclass=ABCMeta):
@abstractmethod
def __init__(self, vs):
pass
@abstractmethod
def solve(self, vs, rhs, x0, boundary_val=None):
pass
import os
from petsc4py import PETSc
import numpy as onp
from veros import logger, veros_kernel, runtime_settings as rs, runtime_state as rst
from veros.core import utilities
from veros.core.external.solvers.base import LinearSolver
from veros.core.operators import numpy as npx, update, update_add, at, flush
from veros.core.external.poisson_matrix import assemble_poisson_matrix
STREAM_OPTIONS = {
"solver_type": "bcgs",
"atol": 1e-24,
"rtol": 1e-14,
"max_it": 1000,
"PC_type": "gamg",
"pc_options": {
"pc_gamg_type": "agg",
"pc_gamg_reuse_interpolation": True,
"pc_gamg_threshold": 1e-4,
"pc_gamg_sym_graph": True,
"pc_gamg_agg_nsmooths": 2,
"mg_levels_pc_type": "jacobi",
},
}
PRESSURE_OPTIONS = {
"solver_type": "bcgs",
"atol": 1e-24,
"rtol": 1e-14,
"max_it": 1000,
"PC_type": "gamg",
"pc_options": {
"pc_gamg_type": "agg",
"pc_gamg_reuse_interpolation": True,
"pc_gamg_threshold": 1e-4,
"pc_gamg_sym_graph": True,
"pc_gamg_agg_nsmooths": 2,
"mg_levels_pc_type": "jacobi",
},
}
class PETScSolver(LinearSolver):
def __init__(self, state):
if rst.proc_num > 1 and rs.device == "cpu" and "OMP_NUM_THREADS" not in os.environ:
logger.warning(
"Environment variable OMP_NUM_THREADS is not set, which can lead to severely "
"degraded performance when MPI is used."
)
settings = state.settings
if settings.enable_streamfunction:
options = STREAM_OPTIONS
else:
options = PRESSURE_OPTIONS
if settings.enable_cyclic_x:
boundary_type = ("periodic", "ghosted")
else:
boundary_type = ("ghosted", "ghosted")
self._da = PETSc.DMDA().create(
[settings.nx, settings.ny],
stencil_width=1,
stencil_type="star",
comm=rs.mpi_comm,
proc_sizes=rs.num_proc,
boundary_type=boundary_type,
ownership_ranges=[
(settings.nx // rs.num_proc[0],) * rs.num_proc[0],
(settings.ny // rs.num_proc[1],) * rs.num_proc[1],
],
)
if rs.device == "gpu":
self._da.setVecType("cuda")
self._da.setMatType("aijcusparse")
self._matrix, self._boundary_mask = self._assemble_poisson_matrix(state)
petsc_options = PETSc.Options()
# setup krylov method
self._ksp = PETSc.KSP()
self._ksp.create(self._da.comm)
self._ksp.setOperators(self._matrix)
self._ksp.setType(options["solver_type"])
self._ksp.setTolerances(atol=options["atol"], rtol=options["rtol"], max_it=options["max_it"])
# preconditioner
self._ksp.getPC().setType(options["PC_type"])
for key in options["pc_options"]:
petsc_options[key] = options["pc_options"][key]
if rs.petsc_options:
petsc_options.insertString(rs.petsc_options)
self._ksp.setFromOptions()
self._ksp.getPC().setFromOptions()
self._rhs_petsc = self._da.createGlobalVec()
self._sol_petsc = self._da.createGlobalVec()
def _petsc_solver(self, rhs, x0):
# hangs on multi-GPU without this
flush()
self._da.getVecArray(self._rhs_petsc)[...] = rhs[2:-2, 2:-2]
self._da.getVecArray(self._sol_petsc)[...] = x0[2:-2, 2:-2]
self._ksp.solve(self._rhs_petsc, self._sol_petsc)
info = self._ksp.getConvergedReason()
iterations = self._ksp.getIterationNumber()
if info < 0:
logger.warning(f"Streamfunction solver did not converge after {iterations} iterations (error code: {info})")
if rs.monitor_streamfunction_residual:
# re-use rhs vector to store residual
rhs_norm = self._rhs_petsc.norm(PETSc.NormType.NORM_2)
self._matrix.multAdd(self._sol_petsc, -self._rhs_petsc, self._rhs_petsc)
residual_norm = self._rhs_petsc.norm(PETSc.NormType.NORM_2)
rel_residual = residual_norm / max(rhs_norm, 1e-22)
if rel_residual > 1e-8:
logger.warning(
f"Streamfunction solver did not achieve required precision (rel. residual: {rel_residual:.2e})"
)
return npx.asarray(self._da.getVecArray(self._sol_petsc)[...])
def solve(self, state, rhs, x0, boundary_val=None):
"""
Arguments:
rhs: Right-hand side vector
x0: Initial guess
boundary_val: Array containing values to set on boundary elements. Defaults to `x0`.
"""
rhs, x0 = prepare_solver_inputs(state, rhs, x0, boundary_val, self._boundary_mask, self._boundary_fac)
linear_solution = self._petsc_solver(rhs, x0)
return update(rhs, at[2:-2, 2:-2], linear_solution)
def _assemble_poisson_matrix(self, state):
diags, offsets, boundary_mask = assemble_poisson_matrix(state)
diags = onp.asarray(diags, dtype=onp.float64)
diags = diags[:, 2:-2, 2:-2]
row = PETSc.Mat.Stencil()
col = PETSc.Mat.Stencil()
(i0, i1), (j0, j1) = self._da.getRanges()
matrix = self._da.getMatrix()
for j in range(j0, j1):
for i in range(i0, i1):
iloc, jloc = i % (state.settings.nx // rs.num_proc[0]), j % (state.settings.ny // rs.num_proc[1])
row.index = (i, j)
for diag, offset in zip(diags, offsets):
io, jo = (i + offset[0], j + offset[1])
col.index = (io, jo)
matrix.setValueStencil(row, col, diag[iloc, jloc])
matrix.assemble()
self._boundary_fac = {
"east": npx.asarray(diags[1][-1, :]),
"west": npx.asarray(diags[2][0, :]),
"north": npx.asarray(diags[3][:, -1]),
"south": npx.asarray(diags[4][:, 0]),
}
return matrix, boundary_mask
@veros_kernel
def prepare_solver_inputs(state, rhs, x0, boundary_val, boundary_mask, boundary_fac):
settings = state.settings
if boundary_val is None:
boundary_val = x0
x0 = utilities.enforce_boundaries(x0, settings.enable_cyclic_x)
rhs = npx.where(boundary_mask, rhs, boundary_val) # set right hand side on boundaries
if settings.enable_streamfunction:
# add dirichlet BC to rhs
if not settings.enable_cyclic_x:
if rst.proc_idx[0] == rs.num_proc[0] - 1:
rhs = update_add(rhs, at[-3, 2:-2], -rhs[-2, 2:-2] * boundary_fac["east"])
if rst.proc_idx[0] == 0:
rhs = update_add(rhs, at[2, 2:-2], -rhs[1, 2:-2] * boundary_fac["west"])
if rst.proc_idx[1] == rs.num_proc[1] - 1:
rhs = update_add(rhs, at[2:-2, -3], -rhs[2:-2, -2] * boundary_fac["north"])
if rst.proc_idx[1] == 0:
rhs = update_add(rhs, at[2:-2, 2], -rhs[2:-2, 1] * boundary_fac["south"])
return rhs, x0
import numpy as onp
import scipy.sparse
import scipy.sparse.linalg as spalg
from veros import logger, veros_kernel, veros_routine, distributed, runtime_state as rst
from veros.variables import allocate
from veros.core.operators import update, at, numpy as npx
from veros.core.external.solvers.base import LinearSolver
from veros.core.external.poisson_matrix import assemble_poisson_matrix
class SciPySolver(LinearSolver):
@veros_routine(
local_variables=(
"hu",
"hv",
"hvr",
"hur",
"dxu",
"dxt",
"dyu",
"dyt",
"cosu",
"cost",
"isle_boundary_mask",
"maskT",
),
dist_safe=False,
)
def __init__(self, state):
self._matrix, self._boundary_mask = self._assemble_poisson_matrix(state)
jacobi_precon = self._jacobi_preconditioner(state, self._matrix)
self._matrix = jacobi_precon * self._matrix
self._rhs_scale = jacobi_precon.diagonal()
self._extra_args = {}
logger.info("Computing ILU preconditioner...")
ilu_preconditioner = spalg.spilu(self._matrix.tocsc(), drop_tol=1e-6, fill_factor=100)
self._extra_args["M"] = spalg.LinearOperator(self._matrix.shape, ilu_preconditioner.solve)
def _scipy_solver(self, state, rhs, x0, boundary_val):
orig_shape = x0.shape
orig_dtype = x0.dtype
rhs = npx.where(self._boundary_mask, rhs, boundary_val) # set right hand side on boundaries
rhs = onp.asarray(rhs.reshape(-1) * self._rhs_scale, dtype="float64")
x0 = onp.asarray(x0.reshape(-1), dtype="float64")
linear_solution, info = spalg.bicgstab(
self._matrix,
rhs,
x0=x0,
atol=1e-8,
tol=0,
maxiter=1000,
**self._extra_args,
)
if info > 0:
logger.warning("Streamfunction solver did not converge after {} iterations", info)
return npx.asarray(linear_solution, dtype=orig_dtype).reshape(orig_shape)
def solve(self, state, rhs, x0, boundary_val=None):
"""
Main solver for streamfunction. Solves a 2D Poisson equation. Uses scipy.sparse.linalg
linear solvers.
Arguments:
rhs: Right-hand side vector
x0: Initial guess
boundary_val: Array containing values to set on boundary elements. Defaults to `x0`.
"""
rhs_global, x0_global, boundary_val = gather_variables(state, rhs, x0, boundary_val)
if rst.proc_rank == 0:
linear_solution = self._scipy_solver(state, rhs_global, x0_global, boundary_val=boundary_val)
else:
linear_solution = npx.empty_like(rhs)
return scatter_variables(state, linear_solution)
@staticmethod
def _jacobi_preconditioner(state, matrix):
"""
Construct a simple Jacobi preconditioner
"""
settings = state.settings
eps = 1e-20
precon = allocate(state.dimensions, ("xu", "yu"), fill=1, local=False)
diag = npx.reshape(matrix.diagonal().copy(), (settings.nx + 4, settings.ny + 4))[2:-2, 2:-2]
precon = update(precon, at[2:-2, 2:-2], npx.where(npx.abs(diag) > eps, 1.0 / (diag + eps), 1.0))
precon = onp.asarray(precon)
return scipy.sparse.dia_matrix((precon.reshape(-1), 0), shape=(precon.size, precon.size)).tocsr()
@staticmethod
def _assemble_poisson_matrix(state):
settings = state.settings
diags, offsets, boundary_mask = assemble_poisson_matrix(state)
# flatten offsets (as expected by scipy.sparse)
offsets = tuple(-dx * diags[0].shape[1] - dy for dx, dy in offsets)
if settings.enable_cyclic_x:
# add cyclic boundary conditions as additional matrix diagonals
# (only works in single-process mode)
wrap_diag_east, wrap_diag_west = (allocate(state.dimensions, ("xu", "yu"), local=False) for _ in range(2))
wrap_diag_east = update(wrap_diag_east, at[2, 2:-2], diags[2][2, 2:-2] * boundary_mask[2, 2:-2])
wrap_diag_west = update(wrap_diag_west, at[-3, 2:-2], diags[1][-3, 2:-2] * boundary_mask[-3, 2:-2])
diags[2] = update(diags[2], at[2, 2:-2], 0.0)
diags[1] = update(diags[1], at[-3, 2:-2], 0.0)
offsets += (-diags[0].shape[1] * (settings.nx - 1), diags[0].shape[1] * (settings.nx - 1))
diags += (wrap_diag_east, wrap_diag_west)
diags = tuple(onp.asarray(diag.reshape(-1)) for diag in (diags))
matrix = scipy.sparse.dia_matrix(
(diags, offsets),
shape=(diags[0].size, diags[0].size),
dtype="float64",
).T.tocsr()
return matrix, boundary_mask
@veros_kernel
def gather_variables(state, rhs, x0, boundary_val):
rhs_global = distributed.gather(rhs, state.dimensions, ("xt", "yt"))
x0_global = distributed.gather(x0, state.dimensions, ("xt", "yt"))
if boundary_val is None:
boundary_val = x0_global
else:
boundary_val = distributed.gather(boundary_val, state.dimensions, ("xt", "yt"))
return rhs_global, x0_global, boundary_val
@veros_kernel
def scatter_variables(state, linear_solution):
return distributed.scatter(linear_solution, state.dimensions, ("xt", "yt"))
from veros import distributed, veros_routine, veros_kernel, runtime_state as rst
from veros.variables import allocate
from veros.core.operators import update, update_add, at, numpy as npx
from veros.core.external.solvers.base import LinearSolver
from veros.core.external.poisson_matrix import assemble_poisson_matrix
@veros_kernel(static_args=("solve_fun",))
def solve_kernel(state, rhs, x0, boundary_val, solve_fun):
rhs_global = distributed.gather(rhs, state.dimensions, ("xt", "yt"))
x0_global = distributed.gather(x0, state.dimensions, ("xt", "yt"))
if boundary_val is None:
boundary_val_global = x0_global
else:
boundary_val_global = distributed.gather(boundary_val, state.dimensions, ("xt", "yt"))
if rst.proc_rank == 0:
linear_solution = solve_fun(rhs_global, x0_global, boundary_val_global)
else:
linear_solution = npx.empty_like(rhs)
return distributed.scatter(linear_solution, state.dimensions, ("xt", "yt"))
class JAXSciPySolver(LinearSolver):
@veros_routine(
local_variables=(
"hu",
"hv",
"hvr",
"hur",
"dxu",
"dxt",
"dyu",
"dyt",
"cosu",
"cost",
"isle_boundary_mask",
"maskT",
),
dist_safe=False,
)
def __init__(self, state):
from jax.scipy.sparse.linalg import bicgstab
matrix_diags, offsets, boundary_mask = self._assemble_poisson_matrix(state)
jacobi_precon = self._jacobi_preconditioner(state, matrix_diags)
matrix_diags = tuple(jacobi_precon * diag for diag in matrix_diags)
@veros_kernel
def linear_solve(rhs, x0, boundary_val):
rhs = npx.where(boundary_mask, rhs, boundary_val) # set right hand side on boundaries
def matmul(rhs):
nx, ny = rhs.shape
res = npx.zeros_like(rhs)
for diag, (di, dj) in zip(matrix_diags, offsets):
assert diag.shape == (nx, ny)
i_s = min(max(di, 0), nx - 1)
i_e = min(max(nx + di, 1), nx)
j_s = min(max(dj, 0), ny - 1)
j_e = min(max(ny + dj, 1), ny)
i_s_inv = nx - i_e
i_e_inv = nx - i_s
j_s_inv = ny - j_e
j_e_inv = ny - j_s
res = update_add(
res,
at[i_s_inv:i_e_inv, j_s_inv:j_e_inv],
diag[i_s_inv:i_e_inv, j_s_inv:j_e_inv] * rhs[i_s:i_e, j_s:j_e],
)
return res
linear_solution, _ = bicgstab(
matmul,
rhs * self._rhs_scale,
x0=x0,
tol=0,
atol=1e-8,
maxiter=10_000,
)
return linear_solution
self._linear_solve = linear_solve
self._rhs_scale = jacobi_precon
def solve(self, state, rhs, x0, boundary_val=None):
"""
Main solver for streamfunction. Solves a 2D Poisson equation. Uses jax.scipy.sparse.linalg
linear solvers.
Arguments:
rhs: Right-hand side vector
x0: Initial guess
boundary_val: Array containing values to set on boundary elements. Defaults to `x0`.
"""
if rst.proc_rank == 0:
linear_solve = self._linear_solve
else:
linear_solve = None
return solve_kernel(state, rhs, x0, boundary_val, linear_solve)
@staticmethod
def _jacobi_preconditioner(state, matrix_diags):
"""
Construct a simple Jacobi preconditioner
"""
eps = 1e-20
precon = allocate(state.dimensions, ("xu", "yu"), fill=1, local=False)
main_diag = matrix_diags[0][2:-2, 2:-2]
precon = update(precon, at[2:-2, 2:-2], npx.where(npx.abs(main_diag) > eps, 1.0 / (main_diag + eps), 1.0))
return precon
@staticmethod
def _assemble_poisson_matrix(state):
settings = state.settings
matrix_diags, offsets, boundary_mask = assemble_poisson_matrix(state)
if settings.enable_cyclic_x:
wrap_diag_east, wrap_diag_west = (allocate(state.dimensions, ("xu", "yu"), local=False) for _ in range(2))
wrap_diag_east = update(wrap_diag_east, at[2, 2:-2], matrix_diags[2][2, 2:-2] * boundary_mask[2, 2:-2])
wrap_diag_west = update(wrap_diag_west, at[-3, 2:-2], matrix_diags[1][-3, 2:-2] * boundary_mask[-3, 2:-2])
matrix_diags[2] = update(matrix_diags[2], at[2, 2:-2], 0.0)
matrix_diags[1] = update(matrix_diags[1], at[-3, 2:-2], 0.0)
offsets += ((settings.nx - 1, 0), (-settings.nx + 1, 0))
matrix_diags += (wrap_diag_east, wrap_diag_west)
return matrix_diags, offsets, boundary_mask
from veros import logger, veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.distributed import global_max
from veros.core import utilities as mainutils
from veros.core.operators import numpy as npx, update, at
from veros.core.external import island, line_integrals, solve_stream
from veros.core.external.solvers import get_linear_solver
@veros_routine
def get_isleperim(state):
"""
preprocess land map using MOMs algorithm for B-grid to determine number of islands
"""
from veros.state import resize_dimension
vs = state.variables
island.isleperim(state)
# now that we know the number of islands we can resize
# all arrays depending on that
nisle = int(global_max(npx.max(vs.land_map)))
resize_dimension(state, "isle", nisle)
vs.isle = npx.arange(nisle)
@veros_routine
def streamfunction_init(state):
"""
prepare for island integrals
"""
vs = state.variables
settings = state.settings
logger.info("Initializing streamfunction method")
get_isleperim(state)
vs.update(boundary_masks(state))
# populate linear solver cache
linear_solver = get_linear_solver(state)
"""
precalculate time independent boundary components of streamfunction
"""
forc = allocate(state.dimensions, ("xt", "yt"))
vs.psin = update(vs.psin, at[...], vs.maskZ[..., -1, npx.newaxis])
for isle in range(state.dimensions["isle"]):
logger.info(f" Solving for boundary contribution by island {isle:d}")
isle_boundary = (
vs.line_dir_east_mask[..., isle]
| vs.line_dir_west_mask[..., isle]
| vs.line_dir_north_mask[..., isle]
| vs.line_dir_south_mask[..., isle]
)
isle_sol = linear_solver.solve(state, forc, vs.psin[:, :, isle], boundary_val=isle_boundary)
vs.psin = update(vs.psin, at[:, :, isle], isle_sol)
vs.psin = mainutils.enforce_boundaries(vs.psin, settings.enable_cyclic_x)
line_psin_out = island_integrals(state)
vs.update(line_psin_out)
"""
take care of initial velocity
"""
# transfer initial velocity to tendency
vs.du = update(vs.du, at[..., vs.tau], vs.u[..., vs.tau] / settings.dt_mom / (1.5 + settings.AB_eps))
vs.dv = update(vs.dv, at[..., vs.tau], vs.v[..., vs.tau] / settings.dt_mom / (1.5 + settings.AB_eps))
vs.u = update(vs.u, at[...], 0)
vs.v = update(vs.v, at[...], 0)
# run streamfunction solver to determine initial barotropic and baroclinic modes
solve_stream.solve_streamfunction(state)
vs.psi = update(vs.psi, at[...], vs.psi[..., vs.taup1, npx.newaxis])
vs.u = update(
vs.u, at[...], mainutils.enforce_boundaries(vs.u[..., vs.taup1, npx.newaxis], settings.enable_cyclic_x)
)
vs.v = update(
vs.v, at[...], mainutils.enforce_boundaries(vs.v[..., vs.taup1, npx.newaxis], settings.enable_cyclic_x)
)
vs.du = update(vs.du, at[..., vs.tau], 0)
vs.dv = update(vs.dv, at[..., vs.tau], 0)
@veros_kernel
def island_integrals(state):
"""
precalculate time independent island integrals
"""
vs = state.variables
uloc = allocate(state.dimensions, ("xt", "yt", "isle"))
vloc = allocate(state.dimensions, ("xt", "yt", "isle"))
uloc = update(
uloc,
at[1:, 1:, :],
-(vs.psin[1:, 1:, :] - vs.psin[1:, :-1, :])
* vs.maskU[1:, 1:, -1, npx.newaxis]
/ vs.dyt[npx.newaxis, 1:, npx.newaxis]
* vs.hur[1:, 1:, npx.newaxis],
)
vloc = update(
vloc,
at[1:, 1:, ...],
(vs.psin[1:, 1:, :] - vs.psin[:-1, 1:, :])
* vs.maskV[1:, 1:, -1, npx.newaxis]
/ (vs.cosu[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
* vs.hvr[1:, 1:, npx.newaxis],
)
vs.line_psin = line_integrals.line_integrals(state, uloc=uloc, vloc=vloc, kind="full")
return KernelOutput(line_psin=vs.line_psin)
@veros_kernel
def boundary_masks(state):
"""
now that the number of islands is known we can allocate the rest of the variables
"""
vs = state.variables
settings = state.settings
boundary_map = vs.land_map[..., npx.newaxis] == npx.arange(1, state.dimensions["isle"] + 1)
if settings.enable_cyclic_x:
vs.line_dir_east_mask = update(
vs.line_dir_east_mask, at[2:-2, 1:-1], boundary_map[3:-1, 1:-1] & ~boundary_map[3:-1, 2:]
)
vs.line_dir_west_mask = update(
vs.line_dir_west_mask, at[2:-2, 1:-1], boundary_map[2:-2, 2:] & ~boundary_map[2:-2, 1:-1]
)
vs.line_dir_south_mask = update(
vs.line_dir_south_mask, at[2:-2, 1:-1], boundary_map[2:-2, 1:-1] & ~boundary_map[3:-1, 1:-1]
)
vs.line_dir_north_mask = update(
vs.line_dir_north_mask, at[2:-2, 1:-1], boundary_map[3:-1, 2:] & ~boundary_map[2:-2, 2:]
)
else:
vs.line_dir_east_mask = update(
vs.line_dir_east_mask, at[1:-1, 1:-1], boundary_map[2:, 1:-1] & ~boundary_map[2:, 2:]
)
vs.line_dir_west_mask = update(
vs.line_dir_west_mask, at[1:-1, 1:-1], boundary_map[1:-1, 2:] & ~boundary_map[1:-1, 1:-1]
)
vs.line_dir_south_mask = update(
vs.line_dir_south_mask, at[1:-1, 1:-1], boundary_map[1:-1, 1:-1] & ~boundary_map[2:, 1:-1]
)
vs.line_dir_north_mask = update(
vs.line_dir_north_mask, at[1:-1, 1:-1], boundary_map[2:, 2:] & ~boundary_map[1:-1, 2:]
)
vs.isle_boundary_mask = ~npx.any(
vs.line_dir_east_mask | vs.line_dir_west_mask | vs.line_dir_south_mask | vs.line_dir_north_mask, axis=2
)
return KernelOutput(
isle_boundary_mask=vs.isle_boundary_mask,
line_dir_east_mask=vs.line_dir_east_mask,
line_dir_west_mask=vs.line_dir_west_mask,
line_dir_south_mask=vs.line_dir_south_mask,
line_dir_north_mask=vs.line_dir_north_mask,
)
from veros.core.operators import numpy as npx
from veros import veros_routine, veros_kernel, KernelOutput
from veros.variables import allocate
from veros.core import numerics, utilities, isoneutral
from veros.core.operators import update, update_add, at
@veros_kernel
def explicit_vert_friction(state):
"""
explicit vertical friction
dissipation is calculated and added to K_diss_v
"""
vs = state.variables
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
flux_top = allocate(state.dimensions, ("xt", "yu", "zt"))
"""
vertical friction of zonal momentum
"""
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[2:-1, 1:-2, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.u[1:-2, 1:-2, 1:, vs.tau] - vs.u[1:-2, 1:-2, :-1, vs.tau])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskU[1:-2, 1:-2, 1:]
* vs.maskU[1:-2, 1:-2, :-1],
)
flux_top = update(flux_top, at[:, :, -1], 0.0)
vs.du_mix = update(vs.du_mix, at[:, :, 0], flux_top[:, :, 0] / vs.dzt[0] * vs.maskU[:, :, 0])
vs.du_mix = update(
vs.du_mix, at[:, :, 1:], (flux_top[:, :, 1:] - flux_top[:, :, :-1]) / vs.dzt[1:] * vs.maskU[:, :, 1:]
)
"""
diagnose dissipation by vertical friction of zonal momentum
"""
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.u[1:-2, 1:-2, 1:, vs.tau] - vs.u[1:-2, 1:-2, :-1, vs.tau])
* flux_top[1:-2, 1:-2, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.ugrid_to_tgrid(state, diss)
vs.K_diss_v = vs.K_diss_v + diss
"""
vertical friction of meridional momentum
"""
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[1:-2, 2:-1, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.v[1:-2, 1:-2, 1:, vs.tau] - vs.v[1:-2, 1:-2, :-1, vs.tau])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskV[1:-2, 1:-2, 1:]
* vs.maskV[1:-2, 1:-2, :-1],
)
flux_top = update(flux_top, at[:, :, -1], 0.0)
vs.dv_mix = update(
vs.dv_mix,
at[:, :, 1:],
(flux_top[:, :, 1:] - flux_top[:, :, :-1]) / vs.dzt[npx.newaxis, npx.newaxis, 1:] * vs.maskV[:, :, 1:],
)
vs.dv_mix = update(vs.dv_mix, at[:, :, 0], flux_top[:, :, 0] / vs.dzt[0] * vs.maskV[:, :, 0])
"""
diagnose dissipation by vertical friction of meridional momentum
"""
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.v[1:-2, 1:-2, 1:, vs.tau] - vs.v[1:-2, 1:-2, :-1, vs.tau])
* flux_top[1:-2, 1:-2, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.vgrid_to_tgrid(state, diss)
vs.K_diss_v = vs.K_diss_v + diss
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_v=vs.K_diss_v)
@veros_kernel
def implicit_vert_friction(state):
"""
vertical friction
dissipation is calculated and added to K_diss_v
"""
vs = state.variables
settings = state.settings
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
a_tri = allocate(state.dimensions, ("xt", "yu", "zt"))[1:-2, 1:-2]
b_tri = allocate(state.dimensions, ("xt", "yu", "zt"))[1:-2, 1:-2]
c_tri = allocate(state.dimensions, ("xt", "yu", "zt"))[1:-2, 1:-2]
d_tri = allocate(state.dimensions, ("xt", "yu", "zt"))[1:-2, 1:-2]
delta = allocate(state.dimensions, ("xt", "yu", "zt"))[1:-2, 1:-2]
flux_top = allocate(state.dimensions, ("xt", "yu", "zt"))
"""
implicit vertical friction of zonal momentum
"""
kss = npx.maximum(vs.kbot[1:-2, 1:-2], vs.kbot[2:-1, 1:-2])
_, water_mask, edge_mask = utilities.create_water_masks(kss, settings.nz)
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[2:-1, 1:-2, :-1])
delta = update(
delta, at[:, :, :-1], settings.dt_mom / vs.dzw[:-1] * fxa * vs.maskU[1:-2, 1:-2, 1:] * vs.maskU[1:-2, 1:-2, :-1]
)
a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri = update(b_tri, at[:, :, 1:], 1 + delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri = update_add(b_tri, at[:, :, 1:-1], delta[:, :, 1:-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1])
b_tri_edge = 1 + delta / vs.dzt[npx.newaxis, npx.newaxis, :]
c_tri = update(c_tri, at[...], -delta / vs.dzt[npx.newaxis, npx.newaxis, :])
d_tri = update(d_tri, at[...], vs.u[1:-2, 1:-2, :, vs.tau])
res = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
vs.u = update(vs.u, at[1:-2, 1:-2, :, vs.taup1], npx.where(water_mask, res, vs.u[1:-2, 1:-2, :, vs.taup1]))
vs.du_mix = update(
vs.du_mix, at[1:-2, 1:-2], (vs.u[1:-2, 1:-2, :, vs.taup1] - vs.u[1:-2, 1:-2, :, vs.tau]) / settings.dt_mom
)
"""
diagnose dissipation by vertical friction of zonal momentum
"""
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[2:-1, 1:-2, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.u[1:-2, 1:-2, 1:, vs.taup1] - vs.u[1:-2, 1:-2, :-1, vs.taup1])
/ vs.dzw[:-1]
* vs.maskU[1:-2, 1:-2, 1:]
* vs.maskU[1:-2, 1:-2, :-1],
)
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.u[1:-2, 1:-2, 1:, vs.tau] - vs.u[1:-2, 1:-2, :-1, vs.tau]) * flux_top[1:-2, 1:-2, :-1] / vs.dzw[:-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.ugrid_to_tgrid(state, diss)
vs.K_diss_v = vs.K_diss_v + diss
"""
implicit vertical friction of meridional momentum
"""
kss = npx.maximum(vs.kbot[1:-2, 1:-2], vs.kbot[1:-2, 2:-1])
_, water_mask, edge_mask = utilities.create_water_masks(kss, settings.nz)
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[1:-2, 2:-1, :-1])
delta = update(
delta,
at[:, :, :-1],
settings.dt_mom
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* fxa
* vs.maskV[1:-2, 1:-2, 1:]
* vs.maskV[1:-2, 1:-2, :-1],
)
a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri = update(b_tri, at[:, :, 1:], 1 + delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri = update_add(b_tri, at[:, :, 1:-1], delta[:, :, 1:-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1])
b_tri_edge = 1 + delta / vs.dzt[npx.newaxis, npx.newaxis, :]
c_tri = update(c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, :-1])
c_tri = update(c_tri, at[:, :, -1], 0.0)
d_tri = update(d_tri, at[...], vs.v[1:-2, 1:-2, :, vs.tau])
res = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
vs.v = update(vs.v, at[1:-2, 1:-2, :, vs.taup1], npx.where(water_mask, res, vs.v[1:-2, 1:-2, :, vs.taup1]))
vs.dv_mix = update(
vs.dv_mix, at[1:-2, 1:-2], (vs.v[1:-2, 1:-2, :, vs.taup1] - vs.v[1:-2, 1:-2, :, vs.tau]) / settings.dt_mom
)
"""
diagnose dissipation by vertical friction of meridional momentum
"""
fxa = 0.5 * (vs.kappaM[1:-2, 1:-2, :-1] + vs.kappaM[1:-2, 2:-1, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.v[1:-2, 1:-2, 1:, vs.taup1] - vs.v[1:-2, 1:-2, :-1, vs.taup1])
/ vs.dzw[:-1]
* vs.maskV[1:-2, 1:-2, 1:]
* vs.maskV[1:-2, 1:-2, :-1],
)
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.v[1:-2, 1:-2, 1:, vs.tau] - vs.v[1:-2, 1:-2, :-1, vs.tau]) * flux_top[1:-2, 1:-2, :-1] / vs.dzw[:-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.vgrid_to_tgrid(state, diss)
vs.K_diss_v = vs.K_diss_v + diss
return KernelOutput(u=vs.u, v=vs.v, du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_v=vs.K_diss_v)
@veros_kernel
def rayleigh_friction(state):
"""
interior Rayleigh friction
dissipation is calculated and added to K_diss_bot
"""
vs = state.variables
settings = state.settings
vs.du_mix = update_add(vs.du_mix, at[...], -1 * vs.maskU * settings.r_ray * vs.u[..., vs.tau])
if settings.enable_conserve_energy:
diss = vs.maskU * settings.r_ray * vs.u[..., vs.tau] ** 2
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_u(state, diss))
vs.dv_mix = update_add(vs.dv_mix, at[...], -1 * vs.maskV * settings.r_ray * vs.v[..., vs.tau])
if settings.enable_conserve_energy:
diss = vs.maskV * settings.r_ray * vs.v[..., vs.tau] ** 2
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_bot=vs.K_diss_bot)
@veros_kernel
def linear_bottom_friction(state):
"""
linear bottom friction
dissipation is calculated and added to K_diss_bot
"""
vs = state.variables
settings = state.settings
if settings.enable_bottom_friction_var:
"""
with spatially varying coefficient
"""
k = npx.maximum(vs.kbot[1:-2, 2:-2], vs.kbot[2:-1, 2:-2]) - 1
mask = npx.arange(settings.nz) == k[:, :, npx.newaxis]
vs.du_mix = update_add(
vs.du_mix,
at[1:-2, 2:-2],
-(vs.maskU[1:-2, 2:-2] * vs.r_bot_var_u[1:-2, 2:-2, npx.newaxis]) * vs.u[1:-2, 2:-2, :, vs.tau] * mask,
)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(
diss,
at[1:-2, 2:-2],
vs.maskU[1:-2, 2:-2]
* vs.r_bot_var_u[1:-2, 2:-2, npx.newaxis]
* vs.u[1:-2, 2:-2, :, vs.tau] ** 2
* mask,
)
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_u(state, diss))
k = npx.maximum(vs.kbot[2:-2, 2:-1], vs.kbot[2:-2, 1:-2]) - 1
mask = npx.arange(settings.nz) == k[:, :, npx.newaxis]
vs.dv_mix = update_add(
vs.dv_mix,
at[2:-2, 1:-2],
-(vs.maskV[2:-2, 1:-2] * vs.r_bot_var_v[2:-2, 1:-2, npx.newaxis]) * vs.v[2:-2, 1:-2, :, vs.tau] * mask,
)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(
diss,
at[2:-2, 1:-2],
vs.maskV[2:-2, 1:-2]
* vs.r_bot_var_v[2:-2, 1:-2, npx.newaxis]
* vs.v[2:-2, 1:-2, :, vs.tau] ** 2
* mask,
)
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_v(state, diss))
else:
"""
with constant coefficient
"""
k = npx.maximum(vs.kbot[1:-2, 2:-2], vs.kbot[2:-1, 2:-2]) - 1
mask = npx.arange(settings.nz) == k[:, :, npx.newaxis]
vs.du_mix = update_add(
vs.du_mix, at[1:-2, 2:-2], -1 * vs.maskU[1:-2, 2:-2] * settings.r_bot * vs.u[1:-2, 2:-2, :, vs.tau] * mask
)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(
diss, at[1:-2, 2:-2], vs.maskU[1:-2, 2:-2] * settings.r_bot * vs.u[1:-2, 2:-2, :, vs.tau] ** 2 * mask
)
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_u(state, diss))
k = npx.maximum(vs.kbot[2:-2, 2:-1], vs.kbot[2:-2, 1:-2]) - 1
mask = npx.arange(settings.nz) == k[:, :, npx.newaxis]
vs.dv_mix = update_add(
vs.dv_mix, at[2:-2, 1:-2], -1 * vs.maskV[2:-2, 1:-2] * settings.r_bot * vs.v[2:-2, 1:-2, :, vs.tau] * mask
)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(
diss, at[2:-2, 1:-2], vs.maskV[2:-2, 1:-2] * settings.r_bot * vs.v[2:-2, 1:-2, :, vs.tau] ** 2 * mask
)
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_bot=vs.K_diss_bot)
@veros_kernel
def quadratic_bottom_friction(state):
"""
quadratic bottom friction
dissipation is calculated and added to K_diss_bot
"""
vs = state.variables
settings = state.settings
# we might want to account for EKE in the drag, also a tidal residual
k = npx.maximum(vs.kbot[1:-2, 2:-2], vs.kbot[2:-1, 2:-2]) - 1
mask = k[..., npx.newaxis] == npx.arange(settings.nz)[npx.newaxis, npx.newaxis, :]
fxa = (
vs.maskV[1:-2, 2:-2, :] * vs.v[1:-2, 2:-2, :, vs.tau] ** 2
+ vs.maskV[1:-2, 1:-3, :] * vs.v[1:-2, 1:-3, :, vs.tau] ** 2
+ vs.maskV[2:-1, 2:-2, :] * vs.v[2:-1, 2:-2, :, vs.tau] ** 2
+ vs.maskV[2:-1, 1:-3, :] * vs.v[2:-1, 1:-3, :, vs.tau] ** 2
)
fxa = npx.sqrt(vs.u[1:-2, 2:-2, :, vs.tau] ** 2 + 0.25 * fxa)
aloc = (
vs.maskU[1:-2, 2:-2, :]
* settings.r_quad_bot
* vs.u[1:-2, 2:-2, :, vs.tau]
* fxa
/ vs.dzt[npx.newaxis, npx.newaxis, :]
* mask
)
vs.du_mix = update_add(vs.du_mix, at[1:-2, 2:-2, :], -aloc)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(diss, at[1:-2, 2:-2, :], aloc * vs.u[1:-2, 2:-2, :, vs.tau])
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_u(state, diss))
k = npx.maximum(vs.kbot[2:-2, 1:-2], vs.kbot[2:-2, 2:-1]) - 1
mask = k[..., npx.newaxis] == npx.arange(settings.nz)[npx.newaxis, npx.newaxis, :]
fxa = (
vs.maskU[2:-2, 1:-2, :] * vs.u[2:-2, 1:-2, :, vs.tau] ** 2
+ vs.maskU[1:-3, 1:-2, :] * vs.u[1:-3, 1:-2, :, vs.tau] ** 2
+ vs.maskU[2:-2, 2:-1, :] * vs.u[2:-2, 2:-1, :, vs.tau] ** 2
+ vs.maskU[1:-3, 2:-1, :] * vs.u[1:-3, 2:-1, :, vs.tau] ** 2
)
fxa = npx.sqrt(vs.v[2:-2, 1:-2, :, vs.tau] ** 2 + 0.25 * fxa)
aloc = (
vs.maskV[2:-2, 1:-2, :]
* settings.r_quad_bot
* vs.v[2:-2, 1:-2, :, vs.tau]
* fxa
/ vs.dzt[npx.newaxis, npx.newaxis, :]
* mask
)
vs.dv_mix = update_add(vs.dv_mix, at[2:-2, 1:-2, :], -aloc)
if settings.enable_conserve_energy:
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(diss, at[2:-2, 1:-2, :], aloc * vs.v[2:-2, 1:-2, :, vs.tau])
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_bot=vs.K_diss_bot)
@veros_kernel
def harmonic_friction(state):
"""
horizontal harmonic friction
dissipation is calculated and added to K_diss_h
"""
vs = state.variables
settings = state.settings
diss = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_east = allocate(state.dimensions, ("xu", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yu", "zt"))
"""
Zonal velocity
"""
if settings.enable_hor_friction_cos_scaling:
fxa = vs.cost**settings.hor_friction_cosPower
flux_east = update(
flux_east,
at[:-1],
settings.A_h
* fxa[npx.newaxis, :, npx.newaxis]
* (vs.u[1:, :, :, vs.tau] - vs.u[:-1, :, :, vs.tau])
/ (vs.cost * vs.dxt[1:, npx.newaxis])[:, :, npx.newaxis]
* vs.maskU[1:]
* vs.maskU[:-1],
)
fxa = vs.cosu**settings.hor_friction_cosPower
flux_north = update(
flux_north,
at[:, :-1],
settings.A_h
* fxa[npx.newaxis, :-1, npx.newaxis]
* (vs.u[:, 1:, :, vs.tau] - vs.u[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:]
* vs.maskU[:, :-1]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
if settings.enable_noslip_lateral:
flux_north = update_add(
flux_north,
at[:, :-1],
2
* settings.A_h
* fxa[npx.newaxis, :-1, npx.newaxis]
* (vs.u[:, 1:, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:]
* (1 - vs.maskU[:, :-1])
* vs.cosu[npx.newaxis, :-1, npx.newaxis]
- 2
* settings.A_h
* fxa[npx.newaxis, :-1, npx.newaxis]
* (vs.u[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* (1 - vs.maskU[:, 1:])
* vs.maskU[:, :-1]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
else:
flux_east = update(
flux_east,
at[:-1, :, :],
settings.A_h
* (vs.u[1:, :, :, vs.tau] - vs.u[:-1, :, :, vs.tau])
/ (vs.cost * vs.dxt[1:, npx.newaxis])[:, :, npx.newaxis]
* vs.maskU[1:]
* vs.maskU[:-1],
)
flux_north = update(
flux_north,
at[:, :-1, :],
settings.A_h
* (vs.u[:, 1:, :, vs.tau] - vs.u[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:]
* vs.maskU[:, :-1]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
if settings.enable_noslip_lateral:
flux_north = update_add(
flux_north,
at[:, :-1],
2
* settings.A_h
* vs.u[:, 1:, :, vs.tau]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:]
* (1 - vs.maskU[:, :-1])
* vs.cosu[npx.newaxis, :-1, npx.newaxis]
- 2
* settings.A_h
* vs.u[:, :-1, :, vs.tau]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* (1 - vs.maskU[:, 1:])
* vs.maskU[:, :-1]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
"""
update tendency
"""
vs.du_mix = update_add(
vs.du_mix,
at[2:-2, 2:-2, :],
vs.maskU[2:-2, 2:-2]
* (
(flux_east[2:-2, 2:-2] - flux_east[1:-3, 2:-2])
/ (vs.cost[2:-2] * vs.dxu[2:-2, npx.newaxis])[:, :, npx.newaxis]
+ (flux_north[2:-2, 2:-2] - flux_north[2:-2, 1:-3])
/ (vs.cost[2:-2] * vs.dyt[2:-2])[npx.newaxis, :, npx.newaxis]
),
)
if settings.enable_conserve_energy:
"""
diagnose dissipation by lateral friction
"""
diss = update(
diss,
at[1:-2, 2:-2],
0.5
* (
(vs.u[2:-1, 2:-2, :, vs.tau] - vs.u[1:-2, 2:-2, :, vs.tau]) * flux_east[1:-2, 2:-2]
+ (vs.u[1:-2, 2:-2, :, vs.tau] - vs.u[:-3, 2:-2, :, vs.tau]) * flux_east[:-3, 2:-2]
)
/ (vs.cost[2:-2] * vs.dxu[1:-2, npx.newaxis])[:, :, npx.newaxis]
+ 0.5
* (
(vs.u[1:-2, 3:-1, :, vs.tau] - vs.u[1:-2, 2:-2, :, vs.tau]) * flux_north[1:-2, 2:-2]
+ (vs.u[1:-2, 2:-2, :, vs.tau] - vs.u[1:-2, 1:-3, :, vs.tau]) * flux_north[1:-2, 1:-3]
)
/ (vs.cost[2:-2] * vs.dyt[2:-2])[npx.newaxis, :, npx.newaxis],
)
vs.K_diss_h = numerics.calc_diss_u(state, diss)
"""
Meridional velocity
"""
if settings.enable_hor_friction_cos_scaling:
flux_east = update(
flux_east,
at[:-1],
settings.A_h
* vs.cosu[npx.newaxis, :, npx.newaxis] ** settings.hor_friction_cosPower
* (vs.v[1:, :, :, vs.tau] - vs.v[:-1, :, :, vs.tau])
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* vs.maskV[1:]
* vs.maskV[:-1],
)
if settings.enable_noslip_lateral:
flux_east = update_add(
flux_east,
at[:-1],
2
* settings.A_h
* fxa[npx.newaxis, :, npx.newaxis]
* vs.v[1:, :, :, vs.tau]
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* vs.maskV[1:]
* (1 - vs.maskV[:-1])
- 2
* settings.A_h
* fxa[npx.newaxis, :, npx.newaxis]
* vs.v[:-1, :, :, vs.tau]
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* (1 - vs.maskV[1:])
* vs.maskV[:-1],
)
flux_north = update(
flux_north,
at[:, :-1],
settings.A_h
* vs.cost[npx.newaxis, 1:, npx.newaxis] ** settings.hor_friction_cosPower
* (vs.v[:, 1:, :, vs.tau] - vs.v[:, :-1, :, vs.tau])
/ vs.dyt[npx.newaxis, 1:, npx.newaxis]
* vs.cost[npx.newaxis, 1:, npx.newaxis]
* vs.maskV[:, :-1]
* vs.maskV[:, 1:],
)
else:
flux_east = update(
flux_east,
at[:-1],
settings.A_h
* (vs.v[1:, :, :, vs.tau] - vs.v[:-1, :, :, vs.tau])
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* vs.maskV[1:]
* vs.maskV[:-1],
)
if settings.enable_noslip_lateral:
flux_east = update_add(
flux_east,
at[:-1],
2
* settings.A_h
* vs.v[1:, :, :, vs.tau]
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* vs.maskV[1:]
* (1 - vs.maskV[:-1])
- 2
* settings.A_h
* vs.v[:-1, :, :, vs.tau]
/ (vs.cosu * vs.dxu[:-1, npx.newaxis])[:, :, npx.newaxis]
* (1 - vs.maskV[1:])
* vs.maskV[:-1],
)
flux_north = update(
flux_north,
at[:, :-1],
settings.A_h
* (vs.v[:, 1:, :, vs.tau] - vs.v[:, :-1, :, vs.tau])
/ vs.dyt[npx.newaxis, 1:, npx.newaxis]
* vs.cost[npx.newaxis, 1:, npx.newaxis]
* vs.maskV[:, :-1]
* vs.maskV[:, 1:],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
"""
update tendency
"""
vs.dv_mix = update_add(
vs.dv_mix,
at[2:-2, 2:-2],
vs.maskV[2:-2, 2:-2]
* (
(flux_east[2:-2, 2:-2] - flux_east[1:-3, 2:-2])
/ (vs.cosu[2:-2] * vs.dxt[2:-2, npx.newaxis])[:, :, npx.newaxis]
+ (flux_north[2:-2, 2:-2] - flux_north[2:-2, 1:-3])
/ (vs.dyu[2:-2] * vs.cosu[2:-2])[npx.newaxis, :, npx.newaxis]
),
)
if settings.enable_conserve_energy:
"""
diagnose dissipation by lateral friction
"""
diss = update(
diss,
at[2:-2, 1:-2],
0.5
* (
(vs.v[3:-1, 1:-2, :, vs.tau] - vs.v[2:-2, 1:-2, :, vs.tau]) * flux_east[2:-2, 1:-2]
+ (vs.v[2:-2, 1:-2, :, vs.tau] - vs.v[1:-3, 1:-2, :, vs.tau]) * flux_east[1:-3, 1:-2]
)
/ (vs.cosu[1:-2] * vs.dxt[2:-2, npx.newaxis])[:, :, npx.newaxis]
+ 0.5
* (
(vs.v[2:-2, 2:-1, :, vs.tau] - vs.v[2:-2, 1:-2, :, vs.tau]) * flux_north[2:-2, 1:-2]
+ (vs.v[2:-2, 1:-2, :, vs.tau] - vs.v[2:-2, :-3, :, vs.tau]) * flux_north[2:-2, :-3]
)
/ (vs.cosu[1:-2] * vs.dyu[1:-2])[npx.newaxis, :, npx.newaxis],
)
vs.K_diss_h = update_add(vs.K_diss_h, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_h=vs.K_diss_h)
@veros_kernel
def biharmonic_friction(state):
"""
horizontal biharmonic friction
dissipation is calculated and added to K_diss_h
"""
vs = state.variables
settings = state.settings
flux_east = allocate(state.dimensions, ("xu", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yu", "zt"))
visc = npx.sqrt(abs(settings.A_hbi))
# each of these enters twice, so we halve the power
cost_scaled = vs.cost ** (0.5 * settings.biharmonic_friction_cosPower)
cosu_scaled = vs.cosu ** (0.5 * settings.biharmonic_friction_cosPower)
"""
Zonal velocity
"""
flux_east = update(
flux_east,
at[:-1, :, :],
visc
* cost_scaled[npx.newaxis, :, npx.newaxis]
* (vs.u[1:, :, :, vs.tau] - vs.u[:-1, :, :, vs.tau])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
* vs.maskU[1:, :, :]
* vs.maskU[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* (vs.u[:, 1:, :, vs.tau] - vs.u[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:, :]
* vs.maskU[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
if settings.enable_noslip_lateral:
flux_north = update_add(
flux_north,
at[:, :-1],
2
* visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* vs.u[:, 1:, :, vs.tau]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:]
* (1 - vs.maskU[:, :-1])
* vs.cosu[npx.newaxis, :-1, npx.newaxis]
- 2
* visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* vs.u[:, :-1, :, vs.tau]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* (1 - vs.maskU[:, 1:])
* vs.maskU[:, :-1]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
del2 = allocate(state.dimensions, ("xt", "yu", "zt"))
del2 = update(
del2,
at[1:, 1:, :],
(flux_east[1:, 1:, :] - flux_east[:-1, 1:, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxu[1:, npx.newaxis, npx.newaxis])
+ (flux_north[1:, 1:, :] - flux_north[1:, :-1, :])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis]),
)
flux_east = update(
flux_east,
at[:-1, :, :],
visc
* cost_scaled[npx.newaxis, :, npx.newaxis]
* (del2[1:, :, :] - del2[:-1, :, :])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
* vs.maskU[1:, :, :]
* vs.maskU[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* (del2[:, 1:, :] - del2[:, :-1, :])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:, :]
* vs.maskU[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
if settings.enable_noslip_lateral:
flux_north = update_add(
flux_north,
at[:, :-1, :],
2
* visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* del2[:, 1:, :]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskU[:, 1:, :]
* (1 - vs.maskU[:, :-1, :])
* vs.cosu[npx.newaxis, :-1, npx.newaxis]
- 2
* visc
* cosu_scaled[npx.newaxis, :-1, npx.newaxis]
* del2[:, :-1, :]
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* (1 - vs.maskU[:, 1:, :])
* vs.maskU[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
"""
update tendency
"""
vs.du_mix = update_add(
vs.du_mix,
at[2:-2, 2:-2, :],
-1
* vs.maskU[2:-2, 2:-2, :]
* (
(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxu[2:-2, npx.newaxis, npx.newaxis])
+ (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
if settings.enable_conserve_energy:
"""
diagnose dissipation by lateral friction
"""
flux_east = utilities.enforce_boundaries(flux_east, settings.enable_cyclic_x)
flux_north = utilities.enforce_boundaries(flux_north, settings.enable_cyclic_x)
diss = allocate(state.dimensions, ("xt", "yu", "zt"))
diss = update(
diss,
at[1:-2, 2:-2, :],
-0.5
* (
(vs.u[2:-1, 2:-2, :, vs.tau] - vs.u[1:-2, 2:-2, :, vs.tau]) * flux_east[1:-2, 2:-2, :]
+ (vs.u[1:-2, 2:-2, :, vs.tau] - vs.u[:-3, 2:-2, :, vs.tau]) * flux_east[:-3, 2:-2, :]
)
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxu[1:-2, npx.newaxis, npx.newaxis])
- 0.5
* (
(vs.u[1:-2, 3:-1, :, vs.tau] - vs.u[1:-2, 2:-2, :, vs.tau]) * flux_north[1:-2, 2:-2, :]
+ (vs.u[1:-2, 2:-2, :, vs.tau] - vs.u[1:-2, 1:-3, :, vs.tau]) * flux_north[1:-2, 1:-3, :]
)
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis]),
)
vs.K_diss_h = numerics.calc_diss_u(state, diss)
"""
Meridional velocity
"""
flux_east = update(
flux_east,
at[:-1, :, :],
visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* (vs.v[1:, :, :, vs.tau] - vs.v[:-1, :, :, vs.tau])
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskV[1:, :, :]
* vs.maskV[:-1, :, :],
)
if settings.enable_noslip_lateral:
flux_east = update_add(
flux_east,
at[:-1, :, :],
2
* visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* vs.v[1:, :, :, vs.tau]
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskV[1:, :, :]
* (1 - vs.maskV[:-1, :, :])
- 2
* visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* vs.v[:-1, :, :, vs.tau]
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* (1 - vs.maskV[1:, :, :])
* vs.maskV[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
visc
* cost_scaled[npx.newaxis, :-1, npx.newaxis]
* (vs.v[:, 1:, :, vs.tau] - vs.v[:, :-1, :, vs.tau])
/ vs.dyt[npx.newaxis, 1:, npx.newaxis]
* vs.cost[npx.newaxis, 1:, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.maskV[:, 1:, :],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
del2 = update(
del2,
at[1:, 1:, :],
(flux_east[1:, 1:, :] - flux_east[:-1, 1:, :])
/ (vs.cosu[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (flux_north[1:, 1:, :] - flux_north[1:, :-1, :])
/ (vs.dyu[npx.newaxis, 1:, npx.newaxis] * vs.cosu[npx.newaxis, 1:, npx.newaxis]),
)
flux_east = update(
flux_east,
at[:-1, :, :],
visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* (del2[1:, :, :] - del2[:-1, :, :])
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskV[1:, :, :]
* vs.maskV[:-1, :, :],
)
if settings.enable_noslip_lateral:
flux_east = update_add(
flux_east,
at[:-1, :, :],
2
* visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* del2[1:, :, :]
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskV[1:, :, :]
* (1 - vs.maskV[:-1, :, :])
- 2
* visc
* cosu_scaled[npx.newaxis, :, npx.newaxis]
* del2[:-1, :, :]
/ (vs.cosu[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* (1 - vs.maskV[1:, :, :])
* vs.maskV[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
visc
* cost_scaled[npx.newaxis, :-1, npx.newaxis]
* (del2[:, 1:, :] - del2[:, :-1, :])
/ vs.dyt[npx.newaxis, 1:, npx.newaxis]
* vs.cost[npx.newaxis, 1:, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.maskV[:, 1:, :],
)
flux_east = update(flux_east, at[-1, :, :], 0.0)
flux_north = update(flux_north, at[:, -1, :], 0.0)
"""
update tendency
"""
vs.dv_mix = update_add(
vs.dv_mix,
at[2:-2, 2:-2, :],
-1
* vs.maskV[2:-2, 2:-2, :]
* (
(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cosu[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.dyu[npx.newaxis, 2:-2, npx.newaxis] * vs.cosu[npx.newaxis, 2:-2, npx.newaxis])
),
)
if settings.enable_conserve_energy:
"""
diagnose dissipation by lateral friction
"""
flux_east = utilities.enforce_boundaries(flux_east, settings.enable_cyclic_x)
flux_north = utilities.enforce_boundaries(flux_north, settings.enable_cyclic_x)
diss = update(
diss,
at[2:-2, 1:-2, :],
-0.5
* (
(vs.v[3:-1, 1:-2, :, vs.tau] - vs.v[2:-2, 1:-2, :, vs.tau]) * flux_east[2:-2, 1:-2, :]
+ (vs.v[2:-2, 1:-2, :, vs.tau] - vs.v[1:-3, 1:-2, :, vs.tau]) * flux_east[1:-3, 1:-2, :]
)
/ (vs.cosu[npx.newaxis, 1:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
- 0.5
* (
(vs.v[2:-2, 2:-1, :, vs.tau] - vs.v[2:-2, 1:-2, :, vs.tau]) * flux_north[2:-2, 1:-2, :]
+ (vs.v[2:-2, 1:-2, :, vs.tau] - vs.v[2:-2, :-3, :, vs.tau]) * flux_north[2:-2, :-3, :]
)
/ (vs.cosu[npx.newaxis, 1:-2, npx.newaxis] * vs.dyu[npx.newaxis, 1:-2, npx.newaxis]),
)
vs.K_diss_h = update_add(vs.K_diss_h, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_h=vs.K_diss_h)
@veros_kernel
def momentum_sources(state):
"""
other momentum sources
dissipation is calculated and added to K_diss_bot
"""
vs = state.variables
settings = state.settings
vs.du_mix = update_add(vs.du_mix, at[...], vs.maskU * vs.u_source)
if settings.enable_conserve_energy:
diss = -1 * vs.maskU * vs.u[..., vs.tau] * vs.u_source
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_u(state, diss))
vs.dv_mix = update_add(vs.dv_mix, at[...], vs.maskV * vs.v_source)
if settings.enable_conserve_energy:
diss = -1 * vs.maskV * vs.v[..., vs.tau] * vs.v_source
vs.K_diss_bot = update_add(vs.K_diss_bot, at[...], numerics.calc_diss_v(state, diss))
return KernelOutput(du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_bot=vs.K_diss_bot)
@veros_routine
def friction(state):
vs = state.variables
settings = state.settings
"""
vertical friction
"""
vs.K_diss_v = update(vs.K_diss_v, at[...], 0.0)
if settings.enable_implicit_vert_friction:
vs.update(implicit_vert_friction(state))
if settings.enable_explicit_vert_friction:
vs.update(explicit_vert_friction(state))
"""
TEM formalism for eddy-driven velocity
"""
if settings.enable_TEM_friction:
vs.update(isoneutral.isoneutral_friction(state))
"""
horizontal friction
"""
if settings.enable_hor_friction:
vs.update(harmonic_friction(state))
if settings.enable_biharmonic_friction:
vs.update(biharmonic_friction(state))
"""
Rayleigh and bottom friction
"""
vs.K_diss_bot = update(vs.K_diss_bot, at[...], 0.0)
if settings.enable_ray_friction:
vs.update(rayleigh_friction(state))
if settings.enable_bottom_friction:
vs.update(linear_bottom_friction(state))
if settings.enable_quadratic_bottom_friction:
vs.update(quadratic_bottom_friction(state))
"""
add user defined forcing
"""
if settings.enable_momentum_sources:
vs.update(momentum_sources(state))
from veros.core.operators import numpy as npx
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import advection, utilities
from veros.core.operators import update, update_add, at
"""
IDEMIX as in Olbers and Eden, 2013
"""
@veros_kernel
def set_idemix_parameter(state):
"""
set main IDEMIX parameter
"""
vs = state.variables
settings = state.settings
bN0 = (
npx.sum(
npx.sqrt(npx.maximum(0.0, vs.Nsqr[:, :, :-1, vs.tau]))
* vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskW[:, :, :-1],
axis=2,
)
+ npx.sqrt(npx.maximum(0.0, vs.Nsqr[:, :, -1, vs.tau])) * 0.5 * vs.dzw[-1:] * vs.maskW[:, :, -1]
)
fxa = npx.sqrt(npx.maximum(0.0, vs.Nsqr[..., vs.tau])) / (1e-22 + npx.abs(vs.coriolis_t[..., npx.newaxis]))
cstar = npx.maximum(1e-2, bN0[:, :, npx.newaxis] / (settings.pi * settings.jstar))
vs.c0 = npx.maximum(0.0, settings.gamma * cstar * gofx2(fxa, settings.pi) * vs.maskW)
vs.v0 = npx.maximum(0.0, settings.gamma * cstar * hofx1(fxa, settings.pi) * vs.maskW)
vs.alpha_c = (
npx.maximum(
1e-4,
settings.mu0 * npx.arccosh(npx.maximum(1.0, fxa)) * npx.abs(vs.coriolis_t[..., npx.newaxis]) / cstar**2,
)
* vs.maskW
)
return KernelOutput(c0=vs.c0, v0=vs.v0, alpha_c=vs.alpha_c)
@veros_kernel
def integrate_idemix_kernel(state):
"""
integrate idemix on W grid
"""
vs = state.variables
settings = state.settings
a_tri, b_tri, c_tri, d_tri, delta = (allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2] for _ in range(5))
forc = allocate(state.dimensions, ("xt", "yt", "zt"))
maxE_iw = allocate(state.dimensions, ("xt", "yt", "zt"))
"""
forcing by EKE dissipation
"""
if settings.enable_eke:
forc = vs.eke_diss_iw
else: # shortcut without EKE model
forc = vs.K_diss_gm + vs.K_diss_h - vs.P_diss_skew
if settings.enable_store_cabbeling_heat:
forc += -vs.P_diss_hmix - vs.P_diss_iso
if settings.enable_eke and (settings.enable_eke_diss_bottom or settings.enable_eke_diss_surfbot):
"""
vertically integrate EKE dissipation and inject at bottom and/or surface
"""
a_loc = npx.sum(vs.dzw[npx.newaxis, npx.newaxis, :-1] * forc[:, :, :-1] * vs.maskW[:, :, :-1], axis=2)
a_loc += 0.5 * forc[:, :, -1] * vs.maskW[:, :, -1] * vs.dzw[-1]
forc = update(forc, at[...], 0.0)
ks = npx.maximum(0, vs.kbot[2:-2, 2:-2] - 1)
mask = ks[:, :, npx.newaxis] == npx.arange(settings.nz)[npx.newaxis, npx.newaxis, :]
if settings.enable_eke_diss_bottom:
forc = update(
forc,
at[2:-2, 2:-2, :],
npx.where(
mask, a_loc[2:-2, 2:-2, npx.newaxis] / vs.dzw[npx.newaxis, npx.newaxis, :], forc[2:-2, 2:-2, :]
),
)
else:
forc = update(
forc,
at[2:-2, 2:-2, :],
npx.where(
mask,
settings.eke_diss_surfbot_frac
* a_loc[2:-2, 2:-2, npx.newaxis]
/ vs.dzw[npx.newaxis, npx.newaxis, :],
forc[2:-2, 2:-2, :],
),
)
forc = update(
forc,
at[2:-2, 2:-2, -1],
(1.0 - settings.eke_diss_surfbot_frac) * a_loc[2:-2, 2:-2] / (0.5 * vs.dzw[-1]),
)
"""
forcing by bottom friction
"""
if not settings.enable_store_bottom_friction_tke:
forc = forc + vs.K_diss_bot
"""
prevent negative dissipation of IW energy
"""
maxE_iw = npx.maximum(0.0, vs.E_iw[:, :, :, vs.tau])
"""
vertical diffusion and dissipation is solved implicitly
"""
_, water_mask, edge_mask = utilities.create_water_masks(vs.kbot[2:-2, 2:-2], settings.nz)
delta = update(
delta,
at[:, :, :-1],
settings.dt_tracer
* settings.tau_v
/ vs.dzt[npx.newaxis, npx.newaxis, 1:]
* 0.5
* (vs.c0[2:-2, 2:-2, :-1] + vs.c0[2:-2, 2:-2, 1:]),
)
delta = update(delta, at[:, :, -1], 0.0)
a_tri = update(
a_tri, at[:, :, 1:-1], -delta[:, :, :-2] * vs.c0[2:-2, 2:-2, :-2] / vs.dzw[npx.newaxis, npx.newaxis, 1:-1]
)
a_tri = update(a_tri, at[:, :, -1], -delta[:, :, -2] / (0.5 * vs.dzw[-1:]) * vs.c0[2:-2, 2:-2, -2])
b_tri = update(
b_tri,
at[:, :, 1:-1],
1
+ delta[:, :, 1:-1] * vs.c0[2:-2, 2:-2, 1:-1] / vs.dzw[npx.newaxis, npx.newaxis, 1:-1]
+ delta[:, :, :-2] * vs.c0[2:-2, 2:-2, 1:-1] / vs.dzw[npx.newaxis, npx.newaxis, 1:-1]
+ settings.dt_tracer * vs.alpha_c[2:-2, 2:-2, 1:-1] * maxE_iw[2:-2, 2:-2, 1:-1],
)
b_tri = update(
b_tri,
at[:, :, -1],
1
+ delta[:, :, -2] / (0.5 * vs.dzw[-1:]) * vs.c0[2:-2, 2:-2, -1]
+ settings.dt_tracer * vs.alpha_c[2:-2, 2:-2, -1] * maxE_iw[2:-2, 2:-2, -1],
)
b_tri_edge = (
1
+ delta / vs.dzw * vs.c0[2:-2, 2:-2, :]
+ settings.dt_tracer * vs.alpha_c[2:-2, 2:-2, :] * maxE_iw[2:-2, 2:-2, :]
)
c_tri = update(
c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzw[npx.newaxis, npx.newaxis, :-1] * vs.c0[2:-2, 2:-2, 1:]
)
d_tri = update(d_tri, at[...], vs.E_iw[2:-2, 2:-2, :, vs.tau] + settings.dt_tracer * forc[2:-2, 2:-2, :])
d_tri_edge = (
d_tri + settings.dt_tracer * vs.forc_iw_bottom[2:-2, 2:-2, npx.newaxis] / vs.dzw[npx.newaxis, npx.newaxis, :]
)
d_tri = update_add(d_tri, at[:, :, -1], settings.dt_tracer * vs.forc_iw_surface[2:-2, 2:-2] / (0.5 * vs.dzw[-1:]))
sol = utilities.solve_implicit(
a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, d_edge=d_tri_edge, edge_mask=edge_mask
)
vs.E_iw = update(vs.E_iw, at[2:-2, 2:-2, :, vs.taup1], npx.where(water_mask, sol, vs.E_iw[2:-2, 2:-2, :, vs.taup1]))
"""
store IW dissipation
"""
vs.iw_diss = vs.alpha_c * maxE_iw * vs.E_iw[..., vs.taup1]
"""
add tendency due to lateral diffusion
"""
flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_top = allocate(state.dimensions, ("xt", "yt", "zt"))
if settings.enable_idemix_hor_diffusion:
flux_east = update(
flux_east,
at[:-1, :, :],
settings.tau_h
* 0.5
* (vs.v0[1:, :, :] + vs.v0[:-1, :, :])
* (vs.v0[1:, :, :] * vs.E_iw[1:, :, :, vs.tau] - vs.v0[:-1, :, :] * vs.E_iw[:-1, :, :, vs.tau])
/ (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
* vs.maskU[:-1, :, :],
)
flux_north = update(
flux_north,
at[:, :-1, :],
settings.tau_h
* 0.5
* (vs.v0[:, 1:, :] + vs.v0[:, :-1, :])
* (vs.v0[:, 1:, :] * vs.E_iw[:, 1:, :, vs.tau] - vs.v0[:, :-1, :] * vs.E_iw[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis]
* vs.maskV[:, :-1, :]
* vs.cosu[npx.newaxis, :-1, npx.newaxis],
)
flux_north = update(flux_north, at[:, -1, :], 0.0)
vs.E_iw = update_add(
vs.E_iw,
at[2:-2, 2:-2, :, vs.taup1],
settings.dt_tracer
* vs.maskW[2:-2, 2:-2, :]
* (
(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
"""
add tendency due to advection
"""
if settings.enable_idemix_superbee_advection:
flux_east, flux_north, flux_top = advection.adv_flux_superbee_wgrid(state, vs.E_iw[:, :, :, vs.tau])
if settings.enable_idemix_upwind_advection:
flux_east, flux_north, flux_top = advection.adv_flux_upwind_wgrid(state, vs.E_iw[:, :, :, vs.tau])
if settings.enable_idemix_superbee_advection or settings.enable_idemix_upwind_advection:
vs.dE_iw = update(
vs.dE_iw,
at[2:-2, 2:-2, :, vs.tau],
vs.maskW[2:-2, 2:-2, :]
* (
-(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
- (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
vs.dE_iw = update_add(vs.dE_iw, at[:, :, 0, vs.tau], -flux_top[:, :, 0] / vs.dzw[0:1])
vs.dE_iw = update_add(
vs.dE_iw,
at[:, :, 1:-1, vs.tau],
-(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / vs.dzw[npx.newaxis, npx.newaxis, 1:-1],
)
vs.dE_iw = update_add(
vs.dE_iw, at[:, :, -1, vs.tau], -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * vs.dzw[-1:])
)
"""
Adam Bashforth time stepping
"""
vs.E_iw = update_add(
vs.E_iw,
at[:, :, :, vs.taup1],
settings.dt_tracer
* (
(1.5 + settings.AB_eps) * vs.dE_iw[:, :, :, vs.tau]
- (0.5 + settings.AB_eps) * vs.dE_iw[:, :, :, vs.taum1]
),
)
return KernelOutput(E_iw=vs.E_iw, dE_iw=vs.dE_iw, iw_diss=vs.iw_diss)
@veros_kernel
def gofx2(x, pi):
x = npx.maximum(3.0, x)
c = 1.0 - (2.0 / pi) * npx.arcsin(1.0 / x)
return 2.0 / pi / c * 0.9 * x ** (-2.0 / 3.0) * (1 - npx.exp(-x / 4.3))
@veros_kernel
def hofx1(x, pi):
valid = x > 1
# replace with dummy value to prevent NaNs
x = npx.where(valid, x, 2)
return npx.where(
valid,
(2.0 / pi) / (1.0 - (2.0 / pi) * npx.arcsin(1.0 / x)) * (x - 1.0) / (x + 1.0),
0,
)
@veros_routine
def integrate_idemix(state):
vs = state.variables
vs.update(integrate_idemix_kernel(state))
from veros.core.isoneutral.isoneutral import ( # noqa: F401
check_isoneutral_slope_crit,
isoneutral_diffusion_pre,
isoneutral_diag_streamfunction,
)
from veros.core.isoneutral.diffusion import ( # noqa: F401
isoneutral_diffusion,
isoneutral_skew_diffusion,
)
from veros.core.isoneutral.friction import ( # noqa: F401
isoneutral_friction,
)
from veros.core.operators import numpy as npx
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import utilities, diffusion
from veros.core.operators import update, update_add, at
@veros_kernel
def _calc_tracer_fluxes(state, tr, K_iso, K_skew):
vs = state.variables
tr_pad = utilities.pad_z_edges(tr[..., vs.tau])
K1 = K_iso - K_skew
K2 = K_iso + K_skew
flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
flux_top = allocate(state.dimensions, ("xt", "yt", "zt"))
"""
construct total isoneutral tracer flux at east face of 'T' cells
"""
diffloc = allocate(state.dimensions, ("xt", "yt", "zt"))[1:-2, 2:-2]
diffloc = update(
diffloc,
at[:, :, 1:],
0.25 * (K1[1:-2, 2:-2, 1:] + K1[1:-2, 2:-2, :-1] + K1[2:-1, 2:-2, 1:] + K1[2:-1, 2:-2, :-1]),
)
diffloc = update(diffloc, at[:, :, 0], 0.5 * (K1[1:-2, 2:-2, 0] + K1[2:-1, 2:-2, 0]))
sumz = 0.0
for kr in range(2):
for ip in range(2):
sumz = sumz + diffloc * vs.Ai_ez[1:-2, 2:-2, :, ip, kr] * (
tr_pad[1 + ip : -2 + ip, 2:-2, 1 + kr : -1 + kr or None] - tr_pad[1 + ip : -2 + ip, 2:-2, kr : -2 + kr]
)
flux_east = update(
flux_east,
at[1:-2, 2:-2, :],
sumz / (4.0 * vs.dzt[npx.newaxis, npx.newaxis, :])
+ (tr[2:-1, 2:-2, :, vs.tau] - tr[1:-2, 2:-2, :, vs.tau])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxu[1:-2, npx.newaxis, npx.newaxis])
* vs.K_11[1:-2, 2:-2, :],
)
"""
construct total isoneutral tracer flux at north face of 'T' cells
"""
diffloc = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 1:-2]
diffloc = update(
diffloc,
at[:, :, 1:],
0.25 * (K1[2:-2, 1:-2, 1:] + K1[2:-2, 1:-2, :-1] + K1[2:-2, 2:-1, 1:] + K1[2:-2, 2:-1, :-1]),
)
diffloc = update(diffloc, at[:, :, 0], 0.5 * (K1[2:-2, 1:-2, 0] + K1[2:-2, 2:-1, 0]))
sumz = 0.0
for kr in range(2):
for jp in range(2):
sumz = sumz + diffloc * vs.Ai_nz[2:-2, 1:-2, :, jp, kr] * (
tr_pad[2:-2, 1 + jp : -2 + jp, 1 + kr : -1 + kr or None] - tr_pad[2:-2, 1 + jp : -2 + jp, kr : -2 + kr]
)
flux_north = update(
flux_north,
at[2:-2, 1:-2, :],
vs.cosu[npx.newaxis, 1:-2, npx.newaxis]
* (
sumz / (4.0 * vs.dzt[npx.newaxis, npx.newaxis, :])
+ (tr[2:-2, 2:-1, :, vs.tau] - tr[2:-2, 1:-2, :, vs.tau])
/ vs.dyu[npx.newaxis, 1:-2, npx.newaxis]
* vs.K_22[2:-2, 1:-2, :]
),
)
"""
compute the vertical tracer flux 'flux_top' containing the K31
and K32 components which are to be solved explicitly. The K33
component will be treated implicitly. Note that there are some
cancellations of dxu(i-1+ip) and dyu(jrow-1+jp)
"""
diffloc = K2[2:-2, 2:-2, :-1]
sumx = 0.0
for ip in range(2):
for kr in range(2):
sumx = sumx + diffloc * vs.Ai_bx[2:-2, 2:-2, :-1, ip, kr] / vs.cost[npx.newaxis, 2:-2, npx.newaxis] * (
tr[2 + ip : -2 + ip, 2:-2, kr : -1 + kr or None, vs.tau]
- tr[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None, vs.tau]
)
sumy = 0.0
for jp in range(2):
for kr in range(2):
sumy = sumy + diffloc * vs.Ai_by[2:-2, 2:-2, :-1, jp, kr] * vs.cosu[
npx.newaxis, 1 + jp : -3 + jp, npx.newaxis
] * (
tr[2:-2, 2 + jp : -2 + jp, kr : -1 + kr or None, vs.tau]
- tr[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None, vs.tau]
)
flux_top = update(
flux_top,
at[2:-2, 2:-2, :-1],
sumx / (4 * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ sumy / (4 * vs.dyt[npx.newaxis, 2:-2, npx.newaxis] * vs.cost[npx.newaxis, 2:-2, npx.newaxis]),
)
flux_top = update(flux_top, at[:, :, -1], 0.0)
return flux_east, flux_north, flux_top
@veros_kernel
def _calc_explicit_part(state, flux_east, flux_north, flux_top):
vs = state.variables
explicit_part = allocate(state.dimensions, ("xt", "yt", "zt"))
explicit_part = update(
explicit_part,
at[2:-2, 2:-2, :],
vs.maskT[2:-2, 2:-2, :]
* (
(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
/ (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
),
)
explicit_part = update_add(explicit_part, at[:, :, 0], vs.maskT[:, :, 0] * flux_top[:, :, 0] / vs.dzt[0])
explicit_part = update_add(
explicit_part,
at[:, :, 1:],
vs.maskT[:, :, 1:] * (flux_top[:, :, 1:] - flux_top[:, :, :-1]) / vs.dzt[npx.newaxis, npx.newaxis, 1:],
)
return explicit_part
@veros_kernel
def _calc_implicit_part(state, tr):
vs = state.variables
settings = state.settings
_, water_mask, edge_mask = utilities.create_water_masks(vs.kbot[2:-2, 2:-2], settings.nz)
a_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
b_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
c_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
delta = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
delta = update(
delta, at[:, :, :-1], settings.dt_tracer / vs.dzw[npx.newaxis, npx.newaxis, :-1] * vs.K_33[2:-2, 2:-2, :-1]
)
delta = update(delta, at[:, :, -1], 0.0)
a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri = update(
b_tri, at[:, :, 1:-1], 1 + (delta[:, :, 1:-1] + delta[:, :, :-2]) / vs.dzt[npx.newaxis, npx.newaxis, 1:-1]
)
b_tri = update(b_tri, at[:, :, -1], 1 + delta[:, :, -2] / vs.dzt[npx.newaxis, npx.newaxis, -1])
b_tri_edge = 1 + (delta[:, :, :] / vs.dzt[npx.newaxis, npx.newaxis, :])
c_tri = update(c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, :-1])
sol = utilities.solve_implicit(
a_tri, b_tri, c_tri, tr[2:-2, 2:-2, :, vs.taup1], water_mask, b_edge=b_tri_edge, edge_mask=edge_mask
)
implicit_part = npx.where(water_mask, sol, tr[2:-2, 2:-2, :, vs.taup1])
return implicit_part
@veros_kernel(static_args=("iso", "skew"))
def isoneutral_diffusion_tracer(state, tr, dtracer_iso, iso=True, skew=False):
"""
Isoneutral diffusion for general tracers
"""
vs = state.variables
settings = state.settings
if iso:
K_iso = vs.K_iso
else:
K_iso = 0.0
if skew:
K_skew = vs.K_gm
else:
K_skew = 0.0
flux_east, flux_north, flux_top = _calc_tracer_fluxes(state, tr, K_iso, K_skew)
"""
add explicit part
"""
dtr = _calc_explicit_part(state, flux_east, flux_north, flux_top)
dtracer_iso = dtracer_iso + dtr
tr = update_add(tr, at[2:-2, 2:-2, :, vs.taup1], settings.dt_tracer * dtr[2:-2, 2:-2, :])
"""
add implicit part
"""
if iso:
new_tr = update(tr, at[2:-2, 2:-2, :, vs.taup1], _calc_implicit_part(state, tr))
dtracer_iso = dtracer_iso + (new_tr[:, :, :, vs.taup1] - tr[:, :, :, vs.taup1]) / settings.dt_tracer
tr = new_tr
return tr, dtracer_iso, flux_east, flux_north, flux_top
@veros_kernel(static_args=("istemp", "iso"))
def isoneutral_diffusion_kernel(state, tr, istemp, iso=True):
vs = state.variables
settings = state.settings
if istemp:
dtracer_iso = vs.dtemp_iso
else:
dtracer_iso = vs.dsalt_iso
tr, dtracer_iso, flux_east, flux_north, flux_top = isoneutral_diffusion_tracer(
state, tr, dtracer_iso, iso=iso, skew=not iso
)
out = {}
if istemp:
out.update(temp=tr, dtemp_iso=dtracer_iso)
else:
out.update(salt=tr, dsalt_iso=dtracer_iso)
"""
dissipation by isopycnal mixing
"""
if settings.enable_conserve_energy:
if istemp:
int_drhodX = vs.int_drhodT[:, :, :, vs.tau]
else:
int_drhodX = vs.int_drhodS[:, :, :, vs.tau]
"""
dissipation interpolated on W-grid
"""
diss = diffusion.compute_dissipation(state, int_drhodX, flux_east, flux_north)
diss_wgrid = diffusion.dissipation_on_wgrid(state, diss, vs.kbot)
if not iso:
vs.P_diss_skew = vs.P_diss_skew + diss_wgrid
else:
vs.P_diss_iso = vs.P_diss_iso + diss_wgrid
"""
diagnose dissipation of dynamic enthalpy by explicit and implicit vertical mixing
"""
fxa = (-int_drhodX[2:-2, 2:-2, 1:] + int_drhodX[2:-2, 2:-2, :-1]) / vs.dzw[npx.newaxis, npx.newaxis, :-1]
if not iso:
vs.P_diss_skew = update_add(
vs.P_diss_skew,
at[2:-2, 2:-2, :-1],
-settings.grav / settings.rho_0 * fxa * flux_top[2:-2, 2:-2, :-1] * vs.maskW[2:-2, 2:-2, :-1],
)
out["P_diss_skew"] = vs.P_diss_skew
else:
vs.P_diss_iso = update_add(
vs.P_diss_iso,
at[2:-2, 2:-2, :-1],
-settings.grav
/ settings.rho_0
* fxa
* (
flux_top[2:-2, 2:-2, :-1] * vs.maskW[2:-2, 2:-2, :-1]
+ vs.K_33[2:-2, 2:-2, :-1]
* (tr[2:-2, 2:-2, 1:, vs.taup1] - tr[2:-2, 2:-2, :-1, vs.taup1])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskW[2:-2, 2:-2, :-1]
),
)
out["P_diss_iso"] = vs.P_diss_iso
return KernelOutput(**out)
@veros_routine
def isoneutral_diffusion(state, tr, istemp):
"""
Isopycnal diffusion for tracer,
following functional formulation by Griffies et al
Dissipation is calculated and stored in P_diss_iso
T/S changes are added to dtemp_iso/dsalt_iso
"""
vs = state.variables
vs.update(isoneutral_diffusion_kernel(state, tr, istemp, iso=True))
@veros_routine
def isoneutral_skew_diffusion(state, tr, istemp):
"""
Isopycnal skew diffusion for tracer,
following functional formulation by Griffies et al
Dissipation is calculated and stored in P_diss_skew
T/S changes are added to dtemp_iso/dsalt_iso
"""
vs = state.variables
vs.update(isoneutral_diffusion_kernel(state, tr, istemp, iso=False))
from veros.core.operators import numpy as npx
from veros import veros_kernel, KernelOutput
from veros.variables import allocate
from veros.core import numerics, utilities
from veros.core.operators import update, update_add, at
@veros_kernel
def isoneutral_friction(state):
"""
vertical friction using TEM formalism for eddy driven velocity
"""
vs = state.variables
settings = state.settings
flux_top = allocate(state.dimensions, ("xt", "yt", "zt"))
delta, a_tri, b_tri, c_tri = (allocate(state.dimensions, ("xt", "yt", "zt"))[1:-2, 1:-2, :] for _ in range(4))
if settings.enable_implicit_vert_friction:
aloc = vs.u[:, :, :, vs.taup1]
else:
aloc = vs.u[:, :, :, vs.tau]
# implicit vertical friction of zonal momentum by GM
ks = npx.maximum(vs.kbot[1:-2, 1:-2], vs.kbot[2:-1, 1:-2])
_, water_mask, edge_mask = utilities.create_water_masks(ks, settings.nz)
fxa = 0.5 * (vs.kappa_gm[1:-2, 1:-2, :] + vs.kappa_gm[2:-1, 1:-2, :])
delta = update(
delta,
at[:, :, :-1],
settings.dt_mom
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* fxa[:, :, :-1]
* vs.maskU[1:-2, 1:-2, 1:]
* vs.maskU[1:-2, 1:-2, :-1],
)
delta = update(delta, at[..., -1], 0.0)
a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri_edge = 1 + delta / vs.dzt[npx.newaxis, npx.newaxis, :]
b_tri = update(
b_tri,
at[:, :, 1:-1],
1
+ delta[:, :, 1:-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1]
+ delta[:, :, :-2] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1],
)
b_tri = update(b_tri, at[:, :, -1], 1 + delta[:, :, -2] / vs.dzt[-1])
c_tri = update(c_tri, at[...], -delta / vs.dzt[npx.newaxis, npx.newaxis, :])
sol = utilities.solve_implicit(
a_tri, b_tri, c_tri, aloc[1:-2, 1:-2, :], water_mask, b_edge=b_tri_edge, edge_mask=edge_mask
)
vs.u = update(vs.u, at[1:-2, 1:-2, :, vs.taup1], npx.where(water_mask, sol, vs.u[1:-2, 1:-2, :, vs.taup1]))
vs.du_mix = update_add(
vs.du_mix,
at[1:-2, 1:-2, :],
(vs.u[1:-2, 1:-2, :, vs.taup1] - aloc[1:-2, 1:-2, :]) / settings.dt_mom * water_mask,
)
if settings.enable_conserve_energy:
# diagnose dissipation
diss = allocate(state.dimensions, ("xt", "yt", "zt"))
fxa = 0.5 * (vs.kappa_gm[1:-2, 1:-2, :-1] + vs.kappa_gm[2:-1, 1:-2, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.u[1:-2, 1:-2, 1:, vs.taup1] - vs.u[1:-2, 1:-2, :-1, vs.taup1])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskU[1:-2, 1:-2, 1:]
* vs.maskU[1:-2, 1:-2, :-1],
)
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.u[1:-2, 1:-2, 1:, vs.tau] - vs.u[1:-2, 1:-2, :-1, vs.tau])
* flux_top[1:-2, 1:-2, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.ugrid_to_tgrid(state, diss)
vs.K_diss_gm = diss
if settings.enable_implicit_vert_friction:
aloc = vs.v[:, :, :, vs.taup1]
else:
aloc = vs.v[:, :, :, vs.tau]
# implicit vertical friction of zonal momentum by GM
ks = npx.maximum(vs.kbot[1:-2, 1:-2], vs.kbot[1:-2, 2:-1])
_, water_mask, edge_mask = utilities.create_water_masks(ks, settings.nz)
fxa = 0.5 * (vs.kappa_gm[1:-2, 1:-2, :] + vs.kappa_gm[1:-2, 2:-1, :])
delta = update(
delta,
at[:, :, :-1],
settings.dt_mom
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* fxa[:, :, :-1]
* vs.maskV[1:-2, 1:-2, 1:]
* vs.maskV[1:-2, 1:-2, :-1],
)
delta = update(delta, at[..., -1], 0.0)
a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
b_tri_edge = 1 + delta / vs.dzt[npx.newaxis, npx.newaxis, :]
b_tri = update(
b_tri,
at[:, :, 1:-1],
1
+ delta[:, :, 1:-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1]
+ delta[:, :, :-2] / vs.dzt[npx.newaxis, npx.newaxis, 1:-1],
)
b_tri = update(b_tri, at[:, :, -1], 1 + delta[:, :, -2] / vs.dzt[-1])
c_tri = update(c_tri, at[...], -delta / vs.dzt[npx.newaxis, npx.newaxis, :])
sol = utilities.solve_implicit(
a_tri, b_tri, c_tri, aloc[1:-2, 1:-2, :], water_mask, b_edge=b_tri_edge, edge_mask=edge_mask
)
vs.v = update(vs.v, at[1:-2, 1:-2, :, vs.taup1], npx.where(water_mask, sol, vs.v[1:-2, 1:-2, :, vs.taup1]))
vs.dv_mix = update_add(
vs.dv_mix,
at[1:-2, 1:-2, :],
(vs.v[1:-2, 1:-2, :, vs.taup1] - aloc[1:-2, 1:-2, :]) / settings.dt_mom * water_mask,
)
if settings.enable_conserve_energy:
# diagnose dissipation
diss = allocate(state.dimensions, ("xt", "yt", "zt"))
fxa = 0.5 * (vs.kappa_gm[1:-2, 1:-2, :-1] + vs.kappa_gm[1:-2, 2:-1, :-1])
flux_top = update(
flux_top,
at[1:-2, 1:-2, :-1],
fxa
* (vs.v[1:-2, 1:-2, 1:, vs.taup1] - vs.v[1:-2, 1:-2, :-1, vs.taup1])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1]
* vs.maskV[1:-2, 1:-2, 1:]
* vs.maskV[1:-2, 1:-2, :-1],
)
diss = update(
diss,
at[1:-2, 1:-2, :-1],
(vs.v[1:-2, 1:-2, 1:, vs.tau] - vs.v[1:-2, 1:-2, :-1, vs.tau])
* flux_top[1:-2, 1:-2, :-1]
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
diss = update(diss, at[:, :, -1], 0.0)
diss = numerics.vgrid_to_tgrid(state, diss)
vs.K_diss_gm = vs.K_diss_gm + diss
return KernelOutput(u=vs.u, v=vs.v, du_mix=vs.du_mix, dv_mix=vs.dv_mix, K_diss_gm=vs.K_diss_gm)
from veros.core.operators import numpy as npx
from veros import logger
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import density, utilities
from veros.core.operators import update, update_add, at
@veros_kernel
def dm_taper(sx, iso_slopec, iso_dslope):
"""
tapering function for isopycnal slopes
"""
return 0.5 * (1.0 + npx.tanh((-npx.abs(sx) + iso_slopec) / iso_dslope))
@veros_kernel
def isoneutral_diffusion_pre(state):
"""
Isopycnal diffusion for tracer
following functional formulation by Griffies et al
Code adopted from MOM2.1
"""
vs = state.variables
settings = state.settings
epsln = 1e-20
dTdx = allocate(state.dimensions, ("xt", "yt", "zt"))
dSdx = allocate(state.dimensions, ("xt", "yt", "zt"))
dTdy = allocate(state.dimensions, ("xt", "yt", "zt"))
dSdy = allocate(state.dimensions, ("xt", "yt", "zt"))
dTdz = allocate(state.dimensions, ("xt", "yt", "zt"))
dSdz = allocate(state.dimensions, ("xt", "yt", "zt"))
"""
drho_dt and drho_ds at centers of T cells
"""
drdT = vs.maskT * density.get_drhodT(state, vs.salt[:, :, :, vs.tau], vs.temp[:, :, :, vs.tau], npx.abs(vs.zt))
drdS = vs.maskT * density.get_drhodS(state, vs.salt[:, :, :, vs.tau], vs.temp[:, :, :, vs.tau], npx.abs(vs.zt))
"""
gradients at top face of T cells
"""
dTdz = update(
dTdz,
at[:, :, :-1],
vs.maskW[:, :, :-1]
* (vs.temp[:, :, 1:, vs.tau] - vs.temp[:, :, :-1, vs.tau])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
dSdz = update(
dSdz,
at[:, :, :-1],
vs.maskW[:, :, :-1]
* (vs.salt[:, :, 1:, vs.tau] - vs.salt[:, :, :-1, vs.tau])
/ vs.dzw[npx.newaxis, npx.newaxis, :-1],
)
"""
gradients at eastern face of T cells
"""
dTdx = update(
dTdx,
at[:-1, :, :],
vs.maskU[:-1, :, :]
* (vs.temp[1:, :, :, vs.tau] - vs.temp[:-1, :, :, vs.tau])
/ (vs.dxu[:-1, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, :, npx.newaxis]),
)
dSdx = update(
dSdx,
at[:-1, :, :],
vs.maskU[:-1, :, :]
* (vs.salt[1:, :, :, vs.tau] - vs.salt[:-1, :, :, vs.tau])
/ (vs.dxu[:-1, npx.newaxis, npx.newaxis] * vs.cost[npx.newaxis, :, npx.newaxis]),
)
"""
gradients at northern face of T cells
"""
dTdy = update(
dTdy,
at[:, :-1, :],
vs.maskV[:, :-1, :]
* (vs.temp[:, 1:, :, vs.tau] - vs.temp[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis],
)
dSdy = update(
dSdy,
at[:, :-1, :],
vs.maskV[:, :-1, :]
* (vs.salt[:, 1:, :, vs.tau] - vs.salt[:, :-1, :, vs.tau])
/ vs.dyu[npx.newaxis, :-1, npx.newaxis],
)
"""
Compute Ai_ez and K11 on center of east face of T cell.
"""
diffloc = allocate(state.dimensions, ("xt", "yt", "zt"))
diffloc = update(
diffloc,
at[1:-2, 2:-2, 1:],
0.25
* (vs.K_iso[1:-2, 2:-2, 1:] + vs.K_iso[1:-2, 2:-2, :-1] + vs.K_iso[2:-1, 2:-2, 1:] + vs.K_iso[2:-1, 2:-2, :-1]),
)
diffloc = update(diffloc, at[1:-2, 2:-2, 0], 0.5 * (vs.K_iso[1:-2, 2:-2, 0] + vs.K_iso[2:-1, 2:-2, 0]))
sumz = allocate(state.dimensions, ("xt", "yt", "zt"))[1:-2, 2:-2]
for kr in range(2):
ki = 0 if kr == 1 else 1
for ip in range(2):
drodxe = (
drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdx[1:-2, 2:-2, ki:]
+ drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdx[1:-2, 2:-2, ki:]
)
drodze = (
drdT[1 + ip : -2 + ip, 2:-2, ki:] * dTdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None]
+ drdS[1 + ip : -2 + ip, 2:-2, ki:] * dSdz[1 + ip : -2 + ip, 2:-2, : -1 + kr or None]
)
sxe = -drodxe / (npx.minimum(0.0, drodze) - epsln)
taper = dm_taper(sxe, settings.iso_slopec, settings.iso_dslope)
sumz = update_add(
sumz,
at[:, :, ki:],
vs.dzw[npx.newaxis, npx.newaxis, : -1 + kr or None]
* vs.maskU[1:-2, 2:-2, ki:]
* npx.maximum(settings.K_iso_steep, diffloc[1:-2, 2:-2, ki:] * taper),
)
vs.Ai_ez = update(vs.Ai_ez, at[1:-2, 2:-2, ki:, ip, kr], taper * sxe * vs.maskU[1:-2, 2:-2, ki:])
vs.K_11 = update(vs.K_11, at[1:-2, 2:-2, :], sumz / (4.0 * vs.dzt[npx.newaxis, npx.newaxis, :]))
"""
Compute Ai_nz and K_22 on center of north face of T cell.
"""
diffloc = update(diffloc, at[...], 0)
diffloc = update(
diffloc,
at[2:-2, 1:-2, 1:],
0.25
* (vs.K_iso[2:-2, 1:-2, 1:] + vs.K_iso[2:-2, 1:-2, :-1] + vs.K_iso[2:-2, 2:-1, 1:] + vs.K_iso[2:-2, 2:-1, :-1]),
)
diffloc = update(diffloc, at[2:-2, 1:-2, 0], 0.5 * (vs.K_iso[2:-2, 1:-2, 0] + vs.K_iso[2:-2, 2:-1, 0]))
sumz = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 1:-2]
for kr in range(2):
ki = 0 if kr == 1 else 1
for jp in range(2):
drodyn = (
drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdy[2:-2, 1:-2, ki:]
+ drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdy[2:-2, 1:-2, ki:]
)
drodzn = (
drdT[2:-2, 1 + jp : -2 + jp, ki:] * dTdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None]
+ drdS[2:-2, 1 + jp : -2 + jp, ki:] * dSdz[2:-2, 1 + jp : -2 + jp, : -1 + kr or None]
)
syn = -drodyn / (npx.minimum(0.0, drodzn) - epsln)
taper = dm_taper(syn, settings.iso_slopec, settings.iso_dslope)
sumz = update_add(
sumz,
at[:, :, ki:],
vs.dzw[npx.newaxis, npx.newaxis, : -1 + kr or None]
* vs.maskV[2:-2, 1:-2, ki:]
* npx.maximum(settings.K_iso_steep, diffloc[2:-2, 1:-2, ki:] * taper),
)
vs.Ai_nz = update(vs.Ai_nz, at[2:-2, 1:-2, ki:, jp, kr], taper * syn * vs.maskV[2:-2, 1:-2, ki:])
vs.K_22 = update(vs.K_22, at[2:-2, 1:-2, :], sumz / (4.0 * vs.dzt[npx.newaxis, npx.newaxis, :]))
"""
compute Ai_bx, Ai_by and K33 on top face of T cell.
"""
sumx = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2, :-1]
sumy = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2, :-1]
for kr in range(2):
drodzb = (
drdT[2:-2, 2:-2, kr : -1 + kr or None] * dTdz[2:-2, 2:-2, :-1]
+ drdS[2:-2, 2:-2, kr : -1 + kr or None] * dSdz[2:-2, 2:-2, :-1]
)
# eastward slopes at the top of T cells
for ip in range(2):
drodxb = (
drdT[2:-2, 2:-2, kr : -1 + kr or None] * dTdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None]
+ drdS[2:-2, 2:-2, kr : -1 + kr or None] * dSdx[1 + ip : -3 + ip, 2:-2, kr : -1 + kr or None]
)
sxb = -drodxb / (npx.minimum(0.0, drodzb) - epsln)
taper = dm_taper(sxb, settings.iso_slopec, settings.iso_dslope)
sumx = (
sumx
+ vs.dxu[1 + ip : -3 + ip, npx.newaxis, npx.newaxis]
* vs.K_iso[2:-2, 2:-2, :-1]
* taper
* sxb**2
* vs.maskW[2:-2, 2:-2, :-1]
)
vs.Ai_bx = update(vs.Ai_bx, at[2:-2, 2:-2, :-1, ip, kr], taper * sxb * vs.maskW[2:-2, 2:-2, :-1])
# northward slopes at the top of T cells
for jp in range(2):
facty = vs.cosu[1 + jp : -3 + jp] * vs.dyu[1 + jp : -3 + jp]
drodyb = (
drdT[2:-2, 2:-2, kr : -1 + kr or None] * dTdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None]
+ drdS[2:-2, 2:-2, kr : -1 + kr or None] * dSdy[2:-2, 1 + jp : -3 + jp, kr : -1 + kr or None]
)
syb = -drodyb / (npx.minimum(0.0, drodzb) - epsln)
taper = dm_taper(syb, settings.iso_slopec, settings.iso_dslope)
sumy = (
sumy
+ facty[npx.newaxis, :, npx.newaxis]
* vs.K_iso[2:-2, 2:-2, :-1]
* taper
* syb**2
* vs.maskW[2:-2, 2:-2, :-1]
)
vs.Ai_by = update(vs.Ai_by, at[2:-2, 2:-2, :-1, jp, kr], taper * syb * vs.maskW[2:-2, 2:-2, :-1])
vs.K_33 = update(
vs.K_33,
at[2:-2, 2:-2, :-1],
sumx / (4 * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
+ sumy / (4 * vs.dyt[npx.newaxis, 2:-2, npx.newaxis] * vs.cost[npx.newaxis, 2:-2, npx.newaxis]),
)
vs.K_33 = update(vs.K_33, at[..., -1], 0.0)
return KernelOutput(
Ai_ez=vs.Ai_ez, Ai_nz=vs.Ai_nz, Ai_bx=vs.Ai_bx, Ai_by=vs.Ai_by, K_11=vs.K_11, K_22=vs.K_22, K_33=vs.K_33
)
@veros_kernel
def isoneutral_diag_streamfunction_kernel(state):
vs = state.variables
K_gm_pad = utilities.pad_z_edges(vs.K_gm)
"""
meridional component at east face of 'T' cells
"""
diffloc = 0.25 * (
K_gm_pad[1:-2, 2:-2, 1:-1] + K_gm_pad[1:-2, 2:-2, :-2] + K_gm_pad[2:-1, 2:-2, 1:-1] + K_gm_pad[2:-1, 2:-2, :-2]
)
vs.B2_gm = update(vs.B2_gm, at[1:-2, 2:-2, :], 0.25 * diffloc * npx.sum(vs.Ai_ez[1:-2, 2:-2, ...], axis=(3, 4)))
"""
zonal component at north face of 'T' cells
"""
diffloc = 0.25 * (
K_gm_pad[2:-2, 1:-2, 1:-1] + K_gm_pad[2:-2, 1:-2, :-2] + K_gm_pad[2:-2, 2:-1, 1:-1] + K_gm_pad[2:-2, 2:-1, :-2]
)
vs.B1_gm = update(vs.B1_gm, at[2:-2, 1:-2, :], -0.25 * diffloc * npx.sum(vs.Ai_nz[2:-2, 1:-2, ...], axis=(3, 4)))
return KernelOutput(B1_gm=vs.B1_gm, B2_gm=vs.B2_gm)
@veros_routine
def isoneutral_diag_streamfunction(state):
"""
calculate hor. components of streamfunction for eddy driven velocity
for diagnostics purpose only
"""
vs = state.variables
settings = state.settings
if not (settings.enable_neutral_diffusion and settings.enable_skew_diffusion):
return
vs.update(isoneutral_diag_streamfunction_kernel(state))
@veros_routine
def check_isoneutral_slope_crit(state):
"""
check linear stability criterion from Griffies et al
"""
vs = state.variables
settings = state.settings
epsln = 1e-20
if settings.enable_neutral_diffusion:
ft1 = 1.0 / (4.0 * settings.K_iso_0 * settings.dt_tracer + epsln)
delta1a = npx.min(
vs.dxt[2:-2, npx.newaxis, npx.newaxis]
* npx.abs(vs.cost[npx.newaxis, 2:-2, npx.newaxis])
* vs.dzt[npx.newaxis, npx.newaxis, :]
* ft1
)
delta1b = npx.min(vs.dyt[npx.newaxis, 2:-2, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :] * ft1)
delta_iso1 = min(vs.dzt[0] * ft1 * vs.dxt[-1] * abs(vs.cost[-1]), min(delta1a, delta1b))
logger.info("Diffusion grid factor delta_iso1 = {}", float(delta_iso1))
if delta_iso1 < settings.iso_slopec:
raise RuntimeError(
"Without latitudinal filtering, delta_iso1 is the steepest "
"isoneutral slope available for linear stability of "
"Redi and GM. Maximum allowable isoneutral slope is "
f"specified as iso_slopec = {settings.iso_slopec}."
)
from veros.core.operators import numpy as npx
from veros import veros_routine, veros_kernel, KernelOutput, runtime_settings
from veros.variables import allocate
from veros.core import friction, external
from veros.core.operators import update, update_add, at
@veros_kernel
def tend_coriolisf(state):
"""
time tendency due to Coriolis force
"""
vs = state.variables
settings = state.settings
vs.du_cor = update(
vs.du_cor,
at[2:-2, 2:-2],
0.25
* vs.maskU[2:-2, 2:-2]
* (
vs.coriolis_t[2:-2, 2:-2, npx.newaxis]
* (vs.v[2:-2, 2:-2, :, vs.tau] + vs.v[2:-2, 1:-3, :, vs.tau])
* vs.dxt[2:-2, npx.newaxis, npx.newaxis]
/ vs.dxu[2:-2, npx.newaxis, npx.newaxis]
+ vs.coriolis_t[3:-1, 2:-2, npx.newaxis]
* (vs.v[3:-1, 2:-2, :, vs.tau] + vs.v[3:-1, 1:-3, :, vs.tau])
* vs.dxt[3:-1, npx.newaxis, npx.newaxis]
/ vs.dxu[2:-2, npx.newaxis, npx.newaxis]
),
)
vs.dv_cor = update(
vs.dv_cor,
at[2:-2, 2:-2],
-0.25
* vs.maskV[2:-2, 2:-2]
* (
vs.coriolis_t[2:-2, 2:-2, npx.newaxis]
* (vs.u[1:-3, 2:-2, :, vs.tau] + vs.u[2:-2, 2:-2, :, vs.tau])
* vs.dyt[npx.newaxis, 2:-2, npx.newaxis]
* vs.cost[npx.newaxis, 2:-2, npx.newaxis]
/ (vs.dyu[npx.newaxis, 2:-2, npx.newaxis] * vs.cosu[npx.newaxis, 2:-2, npx.newaxis])
+ vs.coriolis_t[2:-2, 3:-1, npx.newaxis]
* (vs.u[1:-3, 3:-1, :, vs.tau] + vs.u[2:-2, 3:-1, :, vs.tau])
* vs.dyt[npx.newaxis, 3:-1, npx.newaxis]
* vs.cost[npx.newaxis, 3:-1, npx.newaxis]
/ (vs.dyu[npx.newaxis, 2:-2, npx.newaxis] * vs.cosu[npx.newaxis, 2:-2, npx.newaxis])
),
)
"""
time tendency due to metric terms
"""
if settings.coord_degree:
vs.du_cor = update_add(
vs.du_cor,
at[2:-2, 2:-2],
vs.maskU[2:-2, 2:-2]
* 0.125
* vs.tantr[npx.newaxis, 2:-2, npx.newaxis]
* (
(vs.u[2:-2, 2:-2, :, vs.tau] + vs.u[1:-3, 2:-2, :, vs.tau])
* (vs.v[2:-2, 2:-2, :, vs.tau] + vs.v[2:-2, 1:-3, :, vs.tau])
* vs.dxt[2:-2, npx.newaxis, npx.newaxis]
/ vs.dxu[2:-2, npx.newaxis, npx.newaxis]
+ (vs.u[3:-1, 2:-2, :, vs.tau] + vs.u[2:-2, 2:-2, :, vs.tau])
* (vs.v[3:-1, 2:-2, :, vs.tau] + vs.v[3:-1, 1:-3, :, vs.tau])
* vs.dxt[3:-1, npx.newaxis, npx.newaxis]
/ vs.dxu[2:-2, npx.newaxis, npx.newaxis]
),
)
vs.dv_cor = update_add(
vs.dv_cor,
at[2:-2, 2:-2],
-1
* vs.maskV[2:-2, 2:-2]
* 0.125
* (
vs.tantr[npx.newaxis, 2:-2, npx.newaxis]
* (vs.u[2:-2, 2:-2, :, vs.tau] + vs.u[1:-3, 2:-2, :, vs.tau]) ** 2
* vs.dyt[npx.newaxis, 2:-2, npx.newaxis]
* vs.cost[npx.newaxis, 2:-2, npx.newaxis]
/ (vs.dyu[npx.newaxis, 2:-2, npx.newaxis] * vs.cosu[npx.newaxis, 2:-2, npx.newaxis])
+ vs.tantr[npx.newaxis, 3:-1, npx.newaxis]
* (vs.u[2:-2, 3:-1, :, vs.tau] + vs.u[1:-3, 3:-1, :, vs.tau]) ** 2
* vs.dyt[npx.newaxis, 3:-1, npx.newaxis]
* vs.cost[npx.newaxis, 3:-1, npx.newaxis]
/ (vs.dyu[npx.newaxis, 2:-2, npx.newaxis] * vs.cosu[npx.newaxis, 2:-2, npx.newaxis])
),
)
"""
transfer to time tendencies
"""
vs.du = update(vs.du, at[2:-2, 2:-2, :, vs.tau], vs.du_cor[2:-2, 2:-2])
vs.dv = update(vs.dv, at[2:-2, 2:-2, :, vs.tau], vs.dv_cor[2:-2, 2:-2])
return KernelOutput(du=vs.du, dv=vs.dv, du_cor=vs.du_cor, dv_cor=vs.dv_cor)
@veros_kernel
def tend_tauxyf(state):
"""
wind stress forcing
"""
vs = state.variables
settings = state.settings
if runtime_settings.pyom_compatibility_mode:
# surface_tau* has different units in PyOM
vs.du = update_add(
vs.du, at[2:-2, 2:-2, -1, vs.tau], vs.maskU[2:-2, 2:-2, -1] * vs.surface_taux[2:-2, 2:-2] / vs.dzt[-1]
)
vs.dv = update_add(
vs.dv, at[2:-2, 2:-2, -1, vs.tau], vs.maskV[2:-2, 2:-2, -1] * vs.surface_tauy[2:-2, 2:-2] / vs.dzt[-1]
)
else:
vs.du = update_add(
vs.du,
at[2:-2, 2:-2, -1, vs.tau],
vs.maskU[2:-2, 2:-2, -1] * vs.surface_taux[2:-2, 2:-2] / vs.dzt[-1] / settings.rho_0,
)
vs.dv = update_add(
vs.dv,
at[2:-2, 2:-2, -1, vs.tau],
vs.maskV[2:-2, 2:-2, -1] * vs.surface_tauy[2:-2, 2:-2] / vs.dzt[-1] / settings.rho_0,
)
return KernelOutput(du=vs.du, dv=vs.dv)
@veros_kernel
def momentum_advection(state):
"""
Advection of momentum with second order which is energy conserving
"""
vs = state.variables
"""
Code from MITgcm
"""
utr = vs.u[..., vs.tau] * vs.maskU * vs.dyt[npx.newaxis, :, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :]
vtr = (
vs.dzt[npx.newaxis, npx.newaxis, :]
* vs.cosu[npx.newaxis, :, npx.newaxis]
* vs.dxt[:, npx.newaxis, npx.newaxis]
* vs.v[..., vs.tau]
* vs.maskV
)
wtr = vs.w[..., vs.tau] * vs.maskW * vs.area_t[:, :, npx.newaxis]
"""
for zonal momentum
"""
flux_east = allocate(state.dimensions, ("xu", "yt", "zt"))
flux_north = allocate(state.dimensions, ("xt", "yu", "zt"))
flux_top = allocate(state.dimensions, ("xt", "yt", "zw"))
flux_east = update(
flux_east,
at[1:-2, 2:-2],
0.25 * (vs.u[1:-2, 2:-2, :, vs.tau] + vs.u[2:-1, 2:-2, :, vs.tau]) * (utr[2:-1, 2:-2] + utr[1:-2, 2:-2]),
)
flux_north = update(
flux_north,
at[2:-2, 1:-2],
0.25 * (vs.u[2:-2, 1:-2, :, vs.tau] + vs.u[2:-2, 2:-1, :, vs.tau]) * (vtr[3:-1, 1:-2] + vtr[2:-2, 1:-2]),
)
flux_top = update(
flux_top,
at[2:-2, 2:-2, :-1],
0.25
* (vs.u[2:-2, 2:-2, 1:, vs.tau] + vs.u[2:-2, 2:-2, :-1, vs.tau])
* (wtr[2:-2, 2:-2, :-1] + wtr[3:-1, 2:-2, :-1]),
)
vs.du_adv = update(
vs.du_adv,
at[2:-2, 2:-2],
-1
* vs.maskU[2:-2, 2:-2]
* (flux_east[2:-2, 2:-2] - flux_east[1:-3, 2:-2] + flux_north[2:-2, 2:-2] - flux_north[2:-2, 1:-3])
/ (vs.dzt[npx.newaxis, npx.newaxis, :] * vs.area_u[2:-2, 2:-2, npx.newaxis]),
)
tmp = vs.maskU / (vs.dzt * vs.area_u[:, :, npx.newaxis])
vs.du_adv = vs.du_adv - tmp * flux_top
vs.du_adv = update_add(vs.du_adv, at[:, :, 1:], tmp[:, :, 1:] * flux_top[:, :, :-1])
"""
for meridional momentum
"""
flux_top = update(flux_top, at[...], 0.0)
flux_east = update(
flux_east,
at[1:-2, 2:-2],
0.25 * (vs.v[1:-2, 2:-2, :, vs.tau] + vs.v[2:-1, 2:-2, :, vs.tau]) * (utr[1:-2, 3:-1] + utr[1:-2, 2:-2]),
)
flux_north = update(
flux_north,
at[2:-2, 1:-2],
0.25 * (vs.v[2:-2, 1:-2, :, vs.tau] + vs.v[2:-2, 2:-1, :, vs.tau]) * (vtr[2:-2, 2:-1] + vtr[2:-2, 1:-2]),
)
flux_top = update(
flux_top,
at[2:-2, 2:-2, :-1],
0.25
* (vs.v[2:-2, 2:-2, 1:, vs.tau] + vs.v[2:-2, 2:-2, :-1, vs.tau])
* (wtr[2:-2, 2:-2, :-1] + wtr[2:-2, 3:-1, :-1]),
)
vs.dv_adv = update(
vs.dv_adv,
at[2:-2, 2:-2],
-1
* vs.maskV[2:-2, 2:-2]
* (flux_east[2:-2, 2:-2] - flux_east[1:-3, 2:-2] + flux_north[2:-2, 2:-2] - flux_north[2:-2, 1:-3])
/ (vs.dzt * vs.area_v[2:-2, 2:-2, npx.newaxis]),
)
tmp = vs.maskV / (vs.dzt * vs.area_v[:, :, npx.newaxis])
vs.dv_adv = vs.dv_adv - tmp * flux_top
vs.dv_adv = update_add(vs.dv_adv, at[:, :, 1:], tmp[:, :, 1:] * flux_top[:, :, :-1])
vs.du = update_add(vs.du, at[:, :, :, vs.tau], vs.du_adv)
vs.dv = update_add(vs.dv, at[:, :, :, vs.tau], vs.dv_adv)
return KernelOutput(du=vs.du, dv=vs.dv, du_adv=vs.du_adv, dv_adv=vs.dv_adv)
@veros_routine
def vertical_velocity(state):
vs = state.variables
vs.update(vertical_velocity_kernel(state))
@veros_kernel
def vertical_velocity_kernel(state):
"""
vertical velocity from continuity :
\\int_0^z w_z dz = w(z)-w(0) = - \\int dz (u_x + v_y)
w(z) = -int dz u_x + v_y
"""
vs = state.variables
fxa = allocate(state.dimensions, ("xt", "yt", "zw"))
# integrate from bottom to surface to see error in w
fxa = update(
fxa,
at[1:, 1:, 0],
-1
* vs.maskW[1:, 1:, 0]
* vs.dzt[0]
* (
(vs.u[1:, 1:, 0, vs.taup1] - vs.u[:-1, 1:, 0, vs.taup1])
/ (vs.cost[npx.newaxis, 1:] * vs.dxt[1:, npx.newaxis])
+ (
vs.cosu[npx.newaxis, 1:] * vs.v[1:, 1:, 0, vs.taup1]
- vs.cosu[npx.newaxis, :-1] * vs.v[1:, :-1, 0, vs.taup1]
)
/ (vs.cost[npx.newaxis, 1:] * vs.dyt[npx.newaxis, 1:])
),
)
fxa = update(
fxa,
at[1:, 1:, 1:],
-1
* vs.maskW[1:, 1:, 1:]
* vs.dzt[npx.newaxis, npx.newaxis, 1:]
* (
(vs.u[1:, 1:, 1:, vs.taup1] - vs.u[:-1, 1:, 1:, vs.taup1])
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dxt[1:, npx.newaxis, npx.newaxis])
+ (
vs.cosu[npx.newaxis, 1:, npx.newaxis] * vs.v[1:, 1:, 1:, vs.taup1]
- vs.cosu[npx.newaxis, :-1, npx.newaxis] * vs.v[1:, :-1, 1:, vs.taup1]
)
/ (vs.cost[npx.newaxis, 1:, npx.newaxis] * vs.dyt[npx.newaxis, 1:, npx.newaxis])
),
)
vs.w = update(vs.w, at[1:, 1:, :, vs.taup1], npx.cumsum(fxa[1:, 1:, :], axis=2))
return KernelOutput(w=vs.w)
@veros_routine
def momentum(state):
"""
solve for momentum for taup1
"""
vs = state.variables
"""
time tendency due to Coriolis force
"""
vs.update(tend_coriolisf(state))
"""
wind stress forcing
"""
vs.update(tend_tauxyf(state))
"""
advection
"""
vs.update(momentum_advection(state))
with state.timers["friction"]:
friction.friction(state)
"""
external mode
"""
with state.timers["pressure"]:
if state.settings.enable_streamfunction:
external.solve_streamfunction(state)
else:
external.solve_pressure(state)
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.distributed import global_and
from veros.core import density, diffusion, utilities
from veros.core.operators import update, at, numpy as npx
@veros_kernel
def u_centered_grid(dyt, dyu, yt, yu):
yu = update(yu, at[0], 0)
yu = update(yu, at[1:], npx.cumsum(dyt[1:]))
yt = update(yt, at[0], yu[0] - dyt[0] * 0.5)
yt = update(yt, at[1:], 2 * yu[:-1])
alternating_pattern = npx.ones_like(yt)
alternating_pattern = update(alternating_pattern, at[::2], -1)
yt = update(yt, at[...], alternating_pattern * npx.cumsum(alternating_pattern * yt))
dyu = update(dyu, at[:-1], yt[1:] - yt[:-1])
dyu = update(dyu, at[-1], 2 * dyt[-1] - dyu[-2])
return dyu, yt, yu
@veros_kernel
def calc_grid_spacings_kernel(state):
vs = state.variables
settings = state.settings
if settings.enable_cyclic_x:
vs.dxt = update(vs.dxt, at[-2:], vs.dxt[2:4])
vs.dxt = update(vs.dxt, at[:2], vs.dxt[-4:-2])
else:
vs.dxt = update(vs.dxt, at[-2:], vs.dxt[-3])
vs.dxt = update(vs.dxt, at[:2], vs.dxt[2])
vs.dyt = update(vs.dyt, at[-2:], vs.dyt[-3])
vs.dyt = update(vs.dyt, at[:2], vs.dyt[2])
"""
grid in east/west direction
"""
vs.dxu, vs.xt, vs.xu = u_centered_grid(vs.dxt, vs.dxu, vs.xt, vs.xu)
vs.xt = vs.xt + settings.x_origin - vs.xu[2]
vs.xu = vs.xu + settings.x_origin - vs.xu[2]
if settings.enable_cyclic_x:
vs.xt = update(vs.xt, at[-2:], vs.xt[2:4])
vs.xt = update(vs.xt, at[:2], vs.xt[-4:-2])
vs.xu = update(vs.xu, at[-2:], vs.xt[2:4])
vs.xu = update(vs.xu, at[:2], vs.xu[-4:-2])
vs.dxu = update(vs.dxu, at[-2:], vs.dxu[2:4])
vs.dxu = update(vs.dxu, at[:2], vs.dxu[-4:-2])
"""
grid in north/south direction
"""
vs.dyu, vs.yt, vs.yu = u_centered_grid(vs.dyt, vs.dyu, vs.yt, vs.yu)
vs.yt = vs.yt + settings.y_origin - vs.yu[2]
vs.yu = vs.yu + settings.y_origin - vs.yu[2]
if settings.coord_degree:
"""
convert from degrees to pseudo cartesian grid
"""
vs.dxt = vs.dxt * settings.degtom
vs.dxu = vs.dxu * settings.degtom
vs.dyt = vs.dyt * settings.degtom
vs.dyu = vs.dyu * settings.degtom
"""
grid in vertical direction
"""
vs.dzw, vs.zt, vs.zw = u_centered_grid(vs.dzt, vs.dzw, vs.zt, vs.zw)
vs.zt = vs.zt - vs.zw[-1]
vs.zw = vs.zw - vs.zw[-1] # enforce 0 boundary height
return KernelOutput(
dxt=vs.dxt,
dyt=vs.dyt,
dxu=vs.dxu,
dyu=vs.dyu,
xt=vs.xt,
yt=vs.yt,
xu=vs.xu,
yu=vs.yu,
dzw=vs.dzw,
zt=vs.zt,
zw=vs.zw,
)
@veros_routine(
# all inputs are 1D, so doing this on the main process should be fine
dist_safe=False,
local_variables=(
"dxt",
"dxu",
"xt",
"xu",
"dyt",
"dyu",
"yt",
"yu",
"dzt",
"dzw",
"zt",
"zw",
),
)
def calc_grid_spacings(state):
vs = state.variables
vs.update(calc_grid_spacings_kernel(state))
@veros_kernel
def calc_grid_metrics_kernel(state):
vs = state.variables
settings = state.settings
"""
metric factors
"""
if settings.coord_degree:
vs.cost = update(vs.cost, at[...], npx.cos(vs.yt * settings.pi / 180.0))
vs.cosu = update(vs.cosu, at[...], npx.cos(vs.yu * settings.pi / 180.0))
vs.tantr = update(vs.tantr, at[...], npx.tan(vs.yt * settings.pi / 180.0) / settings.radius)
else:
vs.cost = update(vs.cost, at[...], 1.0)
vs.cosu = update(vs.cosu, at[...], 1.0)
vs.tantr = update(vs.tantr, at[...], 0.0)
"""
precalculate area of boxes
"""
vs.area_t = update(vs.area_t, at[...], vs.cost * vs.dyt * vs.dxt[:, npx.newaxis])
vs.area_u = update(vs.area_u, at[...], vs.cost * vs.dyt * vs.dxu[:, npx.newaxis])
vs.area_v = update(vs.area_v, at[...], vs.cosu * vs.dyu * vs.dxt[:, npx.newaxis])
return KernelOutput(
cost=vs.cost,
cosu=vs.cosu,
tantr=vs.tantr,
area_t=vs.area_t,
area_u=vs.area_u,
area_v=vs.area_v,
)
@veros_routine
def calc_grid(state):
"""
setup grid based on dxt,dyt,dzt and x_origin, y_origin
"""
calc_grid_spacings(state)
vs = state.variables
vs.update(calc_grid_metrics_kernel(state))
@veros_routine
def calc_beta(state):
"""
calculate beta = df/dy
"""
vs = state.variables
settings = state.settings
vs.beta = update(
vs.beta,
at[:, 2:-2],
0.5
* (
(vs.coriolis_t[:, 3:-1] - vs.coriolis_t[:, 2:-2]) / vs.dyu[2:-2]
+ (vs.coriolis_t[:, 2:-2] - vs.coriolis_t[:, 1:-3]) / vs.dyu[1:-3]
),
)
vs.beta = utilities.enforce_boundaries(vs.beta, settings.enable_cyclic_x)
@veros_kernel
def calc_topo_kernel(state):
vs = state.variables
settings = state.settings
"""
close domain
"""
vs.kbot = update(vs.kbot, at[:, :2], 0)
vs.kbot = update(vs.kbot, at[:, -2:], 0)
vs.kbot = utilities.enforce_boundaries(vs.kbot, settings.enable_cyclic_x)
if not settings.enable_cyclic_x:
vs.kbot = update(vs.kbot, at[:2, :], 0)
vs.kbot = update(vs.kbot, at[-2:, :], 0)
"""
Land masks
"""
land_mask = vs.kbot > 0
ks = npx.arange(vs.maskT.shape[2])[npx.newaxis, npx.newaxis, :]
vs.maskT = update(vs.maskT, at[...], land_mask[..., npx.newaxis] & (vs.kbot[..., npx.newaxis] - 1 <= ks))
vs.maskT = utilities.enforce_boundaries(vs.maskT, settings.enable_cyclic_x)
vs.maskU = update(vs.maskU, at[...], vs.maskT)
vs.maskU = update(vs.maskU, at[:-1, :, :], npx.minimum(vs.maskT[:-1, :, :], vs.maskT[1:, :, :]))
vs.maskU = utilities.enforce_boundaries(vs.maskU, settings.enable_cyclic_x)
vs.maskV = update(vs.maskV, at[...], vs.maskT)
vs.maskV = update(vs.maskV, at[:, :-1], npx.minimum(vs.maskT[:, :-1], vs.maskT[:, 1:]))
vs.maskV = utilities.enforce_boundaries(vs.maskV, settings.enable_cyclic_x)
vs.maskZ = update(vs.maskZ, at[...], vs.maskT)
vs.maskZ = update(
vs.maskZ, at[:-1, :-1], npx.minimum(npx.minimum(vs.maskT[:-1, :-1], vs.maskT[:-1, 1:]), vs.maskT[1:, :-1])
)
vs.maskZ = utilities.enforce_boundaries(vs.maskZ, settings.enable_cyclic_x)
vs.maskW = update(vs.maskW, at[...], vs.maskT)
vs.maskW = update(vs.maskW, at[:, :, :-1], npx.minimum(vs.maskT[:, :, :-1], vs.maskT[:, :, 1:]))
"""
total depth
"""
vs.ht = npx.sum(vs.maskT * vs.dzt[npx.newaxis, npx.newaxis, :], axis=2)
vs.hu = npx.sum(vs.maskU * vs.dzt[npx.newaxis, npx.newaxis, :], axis=2)
vs.hv = npx.sum(vs.maskV * vs.dzt[npx.newaxis, npx.newaxis, :], axis=2)
vs.hur = npx.where(vs.hu != 0, 1 / (vs.hu + 1e-22), 0)
vs.hvr = npx.where(vs.hv != 0, 1 / (vs.hv + 1e-22), 0)
return KernelOutput(
maskT=vs.maskT,
maskU=vs.maskU,
maskV=vs.maskV,
maskW=vs.maskW,
maskZ=vs.maskZ,
ht=vs.ht,
hu=vs.hu,
hv=vs.hv,
hur=vs.hur,
hvr=vs.hvr,
kbot=vs.kbot,
)
@veros_routine
def calc_topo(state):
"""
calulate masks, total depth etc
"""
vs = state.variables
vs.update(calc_topo_kernel(state))
@veros_kernel
def calc_initial_conditions_kernel(state):
vs = state.variables
settings = state.settings
vs.temp = utilities.enforce_boundaries(vs.temp, settings.enable_cyclic_x)
vs.salt = utilities.enforce_boundaries(vs.salt, settings.enable_cyclic_x)
vs.rho = density.get_rho(state, vs.salt, vs.temp, npx.abs(vs.zt)[:, npx.newaxis]) * vs.maskT[..., npx.newaxis]
vs.Hd = (
density.get_dyn_enthalpy(state, vs.salt, vs.temp, npx.abs(vs.zt)[:, npx.newaxis]) * vs.maskT[..., npx.newaxis]
)
vs.int_drhodT = update(
vs.int_drhodT, at[...], density.get_int_drhodT(state, vs.salt, vs.temp, npx.abs(vs.zt)[:, npx.newaxis])
)
vs.int_drhodS = update(
vs.int_drhodS, at[...], density.get_int_drhodS(state, vs.salt, vs.temp, npx.abs(vs.zt)[:, npx.newaxis])
)
fxa = -settings.grav / settings.rho_0 / vs.dzw[npx.newaxis, npx.newaxis, :] * vs.maskW
vs.Nsqr = update(
vs.Nsqr,
at[:, :, :-1, :],
fxa[:, :, :-1, npx.newaxis]
* (
density.get_rho(state, vs.salt[:, :, 1:, :], vs.temp[:, :, 1:, :], npx.abs(vs.zt)[:-1, npx.newaxis])
- vs.rho[:, :, :-1, :]
),
)
vs.Nsqr = update(vs.Nsqr, at[:, :, -1, :], vs.Nsqr[:, :, -2, :])
return KernelOutput(
salt=vs.salt,
temp=vs.temp,
rho=vs.rho,
Hd=vs.Hd,
int_drhodT=vs.int_drhodT,
int_drhodS=vs.int_drhodS,
Nsqr=vs.Nsqr,
)
@veros_routine
def calc_initial_conditions(state):
"""
calculate dyn. enthalp, etc
"""
vs = state.variables
if npx.any(vs.salt < 0.0):
raise RuntimeError("encountered negative salinity")
vs.update(calc_initial_conditions_kernel(state))
@veros_kernel
def ugrid_to_tgrid(state, a):
vs = state.variables
b = update(
a,
at[2:-2, :, :],
(
vs.dxu[2:-2, npx.newaxis, npx.newaxis] * a[2:-2, :, :]
+ vs.dxu[1:-3, npx.newaxis, npx.newaxis] * a[1:-3, :, :]
)
/ (2 * vs.dxt[2:-2, npx.newaxis, npx.newaxis]),
)
return b
@veros_kernel
def vgrid_to_tgrid(state, a):
vs = state.variables
b = update(
a,
at[:, 2:-2, :],
(vs.area_v[:, 2:-2, npx.newaxis] * a[:, 2:-2, :] + vs.area_v[:, 1:-3, npx.newaxis] * a[:, 1:-3, :])
/ (2 * vs.area_t[:, 2:-2, npx.newaxis]),
)
return b
@veros_kernel
def calc_diss_u(state, diss):
vs = state.variables
ks = allocate(state.dimensions, ("xt", "yt"))
ks = update(ks, at[1:-2, 2:-2], npx.maximum(vs.kbot[1:-2, 2:-2], vs.kbot[2:-1, 2:-2]))
diss_u = diffusion.dissipation_on_wgrid(state, diss, ks)
return ugrid_to_tgrid(state, diss_u)
@veros_kernel
def calc_diss_v(state, diss):
vs = state.variables
ks = allocate(state.dimensions, ("xt", "yt"))
ks = update(ks, at[2:-2, 1:-2], npx.maximum(vs.kbot[2:-2, 1:-2], vs.kbot[2:-2, 2:-1]))
diss_v = diffusion.dissipation_on_wgrid(state, diss, ks)
return vgrid_to_tgrid(state, diss_v)
@veros_kernel
def sanity_check(state):
return global_and(npx.all(npx.isfinite(state.variables.u)))
import warnings
from contextlib import contextmanager
from veros import runtime_settings, runtime_state, veros_kernel
class Index:
__slots__ = ()
@staticmethod
def __getitem__(key):
return key
def noop(*args, **kwargs):
pass
@contextmanager
def make_writeable(*arrs):
orig_writeable = [arr.flags.writeable for arr in arrs]
writeable_arrs = []
try:
for arr in arrs:
arr = arr.copy()
arr.flags.writeable = True
writeable_arrs.append(arr)
if len(writeable_arrs) == 1:
yield writeable_arrs[0]
else:
yield writeable_arrs
finally:
for arr, orig_val in zip(writeable_arrs, orig_writeable):
try:
arr.flags.writeable = orig_val
except ValueError:
pass
def update_numpy(arr, at, to):
with make_writeable(arr) as warr:
warr[at] = to
return warr
def update_add_numpy(arr, at, to):
with make_writeable(arr) as warr:
warr[at] += to
return warr
def update_multiply_numpy(arr, at, to):
with make_writeable(arr) as warr:
warr[at] *= to
return warr
def solve_tridiagonal_numpy(a, b, c, d, water_mask, edge_mask):
import numpy as np
from scipy.linalg import lapack
out = np.zeros(a.shape, dtype=a.dtype)
if not np.any(water_mask):
return out
# remove couplings between slices
with make_writeable(a, c) as warr:
a, c = warr
a[edge_mask] = 0
c[..., -1] = 0
sol = lapack.dgtsv(a[water_mask][1:], b[water_mask], c[water_mask][:-1], d[water_mask])[3]
out[water_mask] = sol
return out
def fori_numpy(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
def scan_numpy(f, init, xs, length=None):
import numpy as np
if xs is None:
xs = [None] * length
carry = init
ys = []
for x in xs:
carry, y = f(carry, x)
ys.append(y)
return carry, np.stack(ys)
@veros_kernel
def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask):
import jax.lax
import jax.numpy as jnp
use_ext = runtime_settings.use_special_tdma
try:
from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT
except ImportError:
if use_ext:
raise
has_ext = False
else:
has_ext = (HAS_CPU_EXT and runtime_settings.device == "cpu") or (
HAS_GPU_EXT and runtime_settings.device == "gpu"
)
if use_ext is None:
if not has_ext:
warnings.warn("Could not use custom TDMA implementation, falling back to pure JAX")
use_ext = False
else:
use_ext = True
if use_ext and not has_ext:
raise RuntimeError("Could not use custom TDMA implementation")
if use_ext:
return tdma(a, b, c, d, water_mask, edge_mask)
a = water_mask * a * jnp.logical_not(edge_mask)
b = jnp.where(water_mask, b, 1.0)
c = water_mask * c
d = water_mask * d
def compute_primes(last_primes, x):
last_cp, last_dp = last_primes
a, b, c, d = x
cp = c / (b - a * last_cp)
dp = (d - a * last_dp) / (b - a * last_cp)
new_primes = (cp, dp)
return new_primes, new_primes
diags_transposed = [jnp.moveaxis(arr, 2, 0) for arr in (a, b, c, d)]
init = jnp.zeros(a.shape[:-1], dtype=a.dtype)
_, primes = jax.lax.scan(compute_primes, (init, init), diags_transposed)
def backsubstitution(last_x, x):
cp, dp = x
new_x = dp - cp * last_x
return new_x, new_x
_, sol = jax.lax.scan(backsubstitution, init, primes, reverse=True)
return jnp.moveaxis(sol, 0, 2)
def update_jax(arr, at, to):
return arr.at[at].set(to)
def update_add_jax(arr, at, to):
return arr.at[at].add(to)
def update_multiply_jax(arr, at, to):
return arr.at[at].multiply(to)
def flush_jax():
import jax
dummy = jax.device_put(0.0) + 0.0
try:
dummy.block_until_ready()
except AttributeError:
# if we are jitting, dummy is not a DeviceArray that we can wait for
pass
numpy = runtime_state.backend_module
if runtime_settings.backend == "numpy":
update = update_numpy
update_add = update_add_numpy
update_multiply = update_multiply_numpy
at = Index()
solve_tridiagonal = solve_tridiagonal_numpy
for_loop = fori_numpy
scan = scan_numpy
flush = noop
elif runtime_settings.backend == "jax":
import jax.lax
update = update_jax
update_add = update_add_jax
update_multiply = update_multiply_jax
at = Index()
solve_tridiagonal = solve_tridiagonal_jax
for_loop = jax.lax.fori_loop
scan = jax.lax.scan
flush = flush_jax
else:
raise ValueError(f"Unrecognized backend {runtime_settings.backend}")
#include <array>
#include <cstddef>
#include <stdexcept>
#include <cuda_runtime.h>
#include "cuda_tdma_kernels.h"
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
template <typename DTYPE>
__global__ void TridiagKernel(
const DTYPE *a,
const DTYPE *b,
const DTYPE *c,
const DTYPE *d,
DTYPE *cp,
DTYPE *dp,
const int num_systems,
const int system_depth
){
// TDMA algorithm
// Solution is written to dp
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_systems;
idx += blockDim.x * gridDim.x) {
if (idx >= system_depth * num_systems) {
return;
}
int indj = idx;
DTYPE denom;
DTYPE ai;
DTYPE b0 = b[indj];
DTYPE cm1 = c[indj] / b0;
DTYPE dm1 = d[indj] / b0;
cp[indj] = cm1;
dp[indj] = dm1;
// forward pass
for (int j = 0; j < system_depth-1; ++j) {
indj += num_systems;
ai = a[indj];
denom = 1.0f / (b[indj] - ai * cm1);
cm1 = c[indj] * denom;
dm1 = (d[indj] - ai * dm1) * denom;
cp[indj] = cm1;
dp[indj] = dm1;
}
// backward pass
for (int j = 0; j < system_depth-1; ++j) {
indj -= num_systems;
dp[indj] -= cp[indj] * dp[indj + num_systems];
}
}
}
// Unpacks a descriptor object from a byte string.
template <typename T>
const T* UnpackDescriptor(const char* opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Descriptor was not encoded correctly");
}
return reinterpret_cast<const T*>(opaque);
}
template <typename DTYPE>
void CudaTridiag(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
const auto& descriptor = *UnpackDescriptor<TridiagDescriptor>(opaque, opaque_len);
const int num_systems = descriptor.num_systems;
const int system_depth = descriptor.system_depth;
const DTYPE* a = reinterpret_cast<const DTYPE*>(buffers[0]);
const DTYPE* b = reinterpret_cast<const DTYPE*>(buffers[1]);
const DTYPE* c = reinterpret_cast<const DTYPE*>(buffers[2]);
const DTYPE* d = reinterpret_cast<const DTYPE*>(buffers[3]);
DTYPE* out = reinterpret_cast<DTYPE*>(buffers[4]);
DTYPE* workspace = reinterpret_cast<DTYPE*>(buffers[5]);
const int BLOCK_SIZE = 128;
const int grid_dim = std::min<int>(1024, (num_systems + BLOCK_SIZE - 1) / BLOCK_SIZE);
TridiagKernel<DTYPE><<<grid_dim, BLOCK_SIZE, 0, stream>>>(a, b, c, d, workspace, out, num_systems, system_depth);
gpuErrchk(cudaPeekAtLastError());
}
void CudaTridiagFloat(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
CudaTridiag<float>(stream, buffers, opaque, opaque_len);
}
void CudaTridiagDouble(cudaStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
CudaTridiag<double>(stream, buffers, opaque, opaque_len);
}
#pragma once
#include <cuda_runtime.h>
struct TridiagDescriptor {
int num_systems;
int system_depth;
};
void CudaTridiagFloat(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len);
void CudaTridiagDouble(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len);
# defensive imports since extensions are optional
try:
from veros.core.special import tdma_cython_
except ImportError:
HAS_CPU_EXT = False
else:
HAS_CPU_EXT = True
try:
from veros.core.special import tdma_cuda_
except ImportError:
HAS_GPU_EXT = False
else:
HAS_GPU_EXT = True
import numpy as np
import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.lib import xla_client
from jax.interpreters import xla, mlir
import jaxlib.mlir.ir as ir
from jaxlib.mlir.dialects import mhlo
try:
from jax.interpreters.mlir import custom_call # noqa: F401
except ImportError:
# TODO: remove once we require jax > 0.4.16
from jaxlib.hlo_helpers import custom_call as _custom_call
# Recent versions return a structure with a field 'results'. We mock it on
# older versions
from collections import namedtuple
MockResult = namedtuple("MockResult", ["results"])
def custom_call(*args, result_types, **kwargs):
results = _custom_call(*args, out_types=result_types, **kwargs)
return MockResult(results)
if HAS_CPU_EXT:
for kernel_name in (b"tdma_cython_double", b"tdma_cython_float"):
fn = tdma_cython_.cpu_custom_call_targets[kernel_name]
xla_client.register_custom_call_target(kernel_name, fn, platform="cpu")
if HAS_GPU_EXT:
for kernel_name in (b"tdma_cuda_double", b"tdma_cuda_float"):
fn = tdma_cuda_.gpu_custom_call_targets[kernel_name]
xla_client.register_custom_call_target(kernel_name, fn, platform="CUDA")
def as_mhlo_constant(val, dtype):
if isinstance(val, mhlo.ConstantOp):
return val
return mhlo.ConstantOp(
ir.DenseElementsAttr.get(np.array([val], dtype=dtype), type=mlir.dtype_to_ir_type(np.dtype(dtype)))
).result
def tdma(a, b, c, d, interior_mask, edge_mask, device=None):
if device is None:
device = jax.default_backend()
if not a.shape == b.shape == c.shape == d.shape:
raise ValueError("all inputs must have identical shape")
if not a.dtype == b.dtype == c.dtype == d.dtype:
raise ValueError("all inputs must have the same dtype")
if device == "cpu":
system_depths = jnp.sum(interior_mask, axis=-1, dtype="int32")
return tdma_p.bind(a, b, c, d, system_depths)
a = interior_mask * a * jnp.logical_not(edge_mask)
b = jnp.where(interior_mask, b, 1.0)
c = interior_mask * c
d = interior_mask * d
return tdma_p.bind(a, b, c, d, system_depths=None)
def tdma_impl(*args, **kwargs):
return xla.apply_primitive(tdma_p, *args, **kwargs)
def tdma_xla_encode_cpu(ctx, a, b, c, d, system_depths):
# try import again to trigger exception on ImportError
from veros.core.special import tdma_cython_ # noqa: F401
x_aval, *_ = ctx.avals_in
np_dtype = x_aval.dtype
x_type = ir.RankedTensorType(a.type)
dtype = x_type.element_type
dims = x_type.shape
supported_dtypes = (
np.dtype(np.float32),
np.dtype(np.float64),
)
if np_dtype not in supported_dtypes:
raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}")
# compute number of elements to vectorize over
num_systems = 1
for el in dims[:-1]:
num_systems *= el
stride = dims[-1]
out_types = [
ir.RankedTensorType.get(dims, dtype),
ir.RankedTensorType.get((stride,), dtype),
]
if np_dtype is np.dtype(np.float32):
kernel = b"tdma_cython_float"
elif np_dtype is np.dtype(np.float64):
kernel = b"tdma_cython_double"
else:
raise RuntimeError("got unrecognized dtype")
out = custom_call(
kernel,
operands=(
a,
b,
c,
d,
system_depths,
as_mhlo_constant(num_systems, np.int64),
as_mhlo_constant(stride, np.int64),
),
result_types=out_types,
)
return out.results[:-1]
def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
# try import again to trigger exception on ImportError
from veros.core.special import tdma_cuda_ # noqa: F401
if system_depths is not None:
raise ValueError("TDMA does not support system_depths argument on GPU")
x_aval, *_ = ctx.avals_in
np_dtype = x_aval.dtype
x_type = ir.RankedTensorType(a.type)
dtype = x_type.element_type
dims = x_type.shape
supported_dtypes = (
np.dtype(np.float32),
np.dtype(np.float64),
)
if np_dtype not in supported_dtypes:
raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}")
# compute number of elements to vectorize over
num_systems = 1
for el in dims[:-1]:
num_systems *= el
system_depth = dims[-1]
if np_dtype is np.dtype(np.float32):
kernel = b"tdma_cuda_float"
elif np_dtype is np.dtype(np.float64):
kernel = b"tdma_cuda_double"
else:
raise RuntimeError("got unrecognized dtype")
descriptor = tdma_cuda_.build_tridiag_descriptor(num_systems, system_depth)
ndims = len(dims)
arr_layout = tuple(range(ndims - 2, -1, -1)) + (ndims - 1,)
out_types = [ir.RankedTensorType.get(dims, dtype), ir.RankedTensorType.get(dims, dtype)]
out_layouts = (arr_layout, arr_layout)
out = custom_call(
kernel,
operands=(a, b, c, d),
result_types=out_types,
result_layouts=out_layouts,
operand_layouts=(arr_layout,) * 4,
backend_config=descriptor,
)
return out.results[:-1]
def tdma_abstract_eval(a, b, c, d, system_depths):
return ShapedArray(a.shape, a.dtype)
tdma_p = Primitive("tdma")
tdma_p.def_impl(tdma_impl)
tdma_p.def_abstract_eval(tdma_abstract_eval)
mlir.register_lowering(tdma_p, tdma_xla_encode_cpu, platform="cpu")
mlir.register_lowering(tdma_p, tdma_xla_encode_gpu, platform="cuda")
# cython: language=c++
from cpython.pycapsule cimport PyCapsule_New
cdef extern from "cuda_runtime_api.h":
ctypedef void* cudaStream_t
cdef struct TridiagDescriptor:
int num_systems
int system_depth
cdef extern from "cuda_tdma_kernels.h":
void CudaTridiagFloat(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len)
void CudaTridiagDouble(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len)
cpdef bytes build_tridiag_descriptor(int num_systems, int system_depth):
cdef TridiagDescriptor desc = TridiagDescriptor(num_systems, system_depth)
return bytes((<char*> &desc)[:sizeof(TridiagDescriptor)])
gpu_custom_call_targets = {}
cdef register_custom_call_target(fn_name, void* fn):
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
gpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
register_custom_call_target(b"tdma_cuda_float", <void*>(CudaTridiagFloat))
register_custom_call_target(b"tdma_cuda_double", <void*>(CudaTridiagDouble))
import cython
from cpython.pycapsule cimport PyCapsule_New
from libc.stdint cimport int32_t, int64_t
@cython.cdivision(True)
cdef void _tdma_cython_double(int32_t n, double* a, double* b, double* c, double* d, double* cp, double* dp) nogil:
cdef:
int32_t i
double denom
if n < 1:
return
cp[0] = c[0] / b[0]
dp[0] = d[0] / b[0]
for i in range(1, n):
denom = 1. / (b[i] - a[i] * cp[i - 1])
cp[i] = c[i] * denom
dp[i] = (d[i] - a[i] * dp[i - 1]) * denom
for i in range(n - 2, -1, -1):
dp[i] -= cp[i] * dp[i + 1]
@cython.cdivision(True)
cdef void _tdma_cython_float(int32_t n, float* a, float* b, float* c, float* d, float* cp, float* dp) nogil:
cdef:
int32_t i
float denom
if n < 1:
return
cp[0] = c[0] / b[0]
dp[0] = d[0] / b[0]
for i in range(1, n):
denom = 1. / (b[i] - a[i] * cp[i - 1])
cp[i] = c[i] * denom
dp[i] = (d[i] - a[i] * dp[i - 1]) * denom
for i in range(n - 2, -1, -1):
dp[i] -= cp[i] * dp[i + 1]
cdef void tdma_cython_double(void** out_ptr, void** data_ptr) nogil:
cdef:
int64_t i, j, system_depth, system_start
int64_t ii = 0
# decode inputs
double* a = (<double*>data_ptr[0])
double* b = (<double*>data_ptr[1])
double* c = (<double*>data_ptr[2])
double* d = (<double*>data_ptr[3])
int32_t* system_depths = (<int32_t*>data_ptr[4])
int64_t num_systems = (<int64_t*>data_ptr[5])[0]
int64_t stride = (<int64_t*>data_ptr[6])[0]
double* out = (<double*>out_ptr[0])
double* workspace = (<double*>out_ptr[1])
for i in range(num_systems):
system_depth = system_depths[i]
system_start = stride - system_depth
for j in range(system_start):
out[ii + j] = 0.
_tdma_cython_double(
system_depth,
&a[ii + system_start],
&b[ii + system_start],
&c[ii + system_start],
&d[ii + system_start],
workspace,
&out[ii + system_start],
)
ii += stride
cdef void tdma_cython_float(void** out_ptr, void** data_ptr) nogil:
cdef:
int64_t i, j, system_depth, system_start
int64_t ii = 0
# decode inputs
float* a = (<float*>data_ptr[0])
float* b = (<float*>data_ptr[1])
float* c = (<float*>data_ptr[2])
float* d = (<float*>data_ptr[3])
int32_t* system_depths = (<int32_t*>data_ptr[4])
int64_t num_systems = (<int64_t*>data_ptr[5])[0]
int64_t stride = (<int64_t*>data_ptr[6])[0]
float* out = (<float*>out_ptr[0])
float* workspace = (<float*>out_ptr[1])
for i in range(num_systems):
system_depth = system_depths[i]
system_start = stride - system_depth
for j in range(system_start):
out[ii + j] = 0.0
_tdma_cython_float(
system_depth,
&a[ii + system_start],
&b[ii + system_start],
&c[ii + system_start],
&d[ii + system_start],
workspace,
&out[ii + system_start],
)
ii += stride
cpu_custom_call_targets = {}
cdef register_custom_call_target(fn_name, void* fn):
cdef const char* name = 'xla._CUSTOM_CALL_TARGET'
cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL)
register_custom_call_target(b'tdma_cython_double', <void*>(tdma_cython_double))
register_custom_call_target(b'tdma_cython_float', <void*>(tdma_cython_float))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment