Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
from veros.core.operators import numpy as npx
from veros import veros_routine, veros_kernel, KernelOutput
from veros.distributed import global_sum
from veros.variables import allocate
from veros.core import advection, diffusion, isoneutral, density, utilities
from veros.core.operators import update, update_add, at
@veros_kernel
def advect_tracer(state, tr):
    """
    calculate time tendency of a tracer due to advection

    Arguments:
        state: Veros state object (variables + settings).
        tr: 3D tracer field (xt, yt, zt) to advect.

    Returns:
        3D array with the advective tendency on the tracer grid.
    """
    vs = state.variables
    settings = state.settings

    # pick the configured advection scheme for the three flux components
    if settings.enable_superbee_advection:
        flux_east, flux_north, flux_top = advection.adv_flux_superbee(state, tr)
    else:
        flux_east, flux_north, flux_top = advection.adv_flux_2nd(state, tr)

    dtr = allocate(state.dimensions, ("xt", "yt", "zt"))

    # horizontal flux divergence on the inner domain (2-cell halo excluded);
    # cost is the latitude metric factor for the spherical grid
    dtr = update(
        dtr,
        at[2:-2, 2:-2, :],
        vs.maskT[2:-2, 2:-2, :]
        * (
            -(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
            / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
            - (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
            / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
        ),
    )

    # vertical flux divergence; level 0 (deepest) has no flux through its lower face
    dtr = update_add(dtr, at[:, :, 0], -1 * vs.maskT[:, :, 0] * flux_top[:, :, 0] / vs.dzt[0])
    dtr = update_add(
        dtr, at[:, :, 1:], -1 * vs.maskT[:, :, 1:] * (flux_top[:, :, 1:] - flux_top[:, :, :-1]) / vs.dzt[1:]
    )

    return dtr
@veros_kernel
def advect_temperature(state):
    """
    Compute the advective tendency of temperature at the current time level.
    """
    vs = state.variables
    vs.dtemp = update(
        vs.dtemp,
        at[..., vs.tau],
        advect_tracer(state, vs.temp[..., vs.tau]),
    )
    return KernelOutput(dtemp=vs.dtemp)
@veros_kernel
def advect_salinity(state):
    """
    Compute the advective tendency of salinity at the current time level.
    """
    vs = state.variables
    vs.dsalt = update(
        vs.dsalt,
        at[..., vs.tau],
        advect_tracer(state, vs.salt[..., vs.tau]),
    )
    return KernelOutput(dsalt=vs.dsalt)
@veros_kernel
def calc_eq_of_state(state, n):
    """
    calculate density, stability frequency, dynamic enthalpy and derivatives
    for time level n from temperature and salinity

    Arguments:
        state: Veros state object.
        n: time level index (e.g. vs.tau or vs.taup1) to evaluate.
    """
    vs = state.variables
    settings = state.settings

    salt = vs.salt[..., n]
    temp = vs.temp[..., n]
    # pressure approximated by depth (zt is negative below the surface)
    press = npx.abs(vs.zt)

    """
    calculate new density
    """
    vs.rho = update(vs.rho, at[..., n], density.get_rho(state, salt, temp, press) * vs.maskT)

    """
    calculate new potential density
    """
    vs.prho = update(vs.prho, at[...], density.get_potential_rho(state, salt, temp) * vs.maskT)

    """
    calculate new dynamic enthalpy and derivatives
    """
    # only needed when the energy budget is closed explicitly
    if settings.enable_conserve_energy:
        vs.Hd = update(vs.Hd, at[..., n], density.get_dyn_enthalpy(state, salt, temp, press) * vs.maskT)
        vs.int_drhodT = update(vs.int_drhodT, at[..., n], density.get_int_drhodT(state, salt, temp, press))
        vs.int_drhodS = update(vs.int_drhodS, at[..., n], density.get_int_drhodS(state, salt, temp, press))

    """
    new stability frequency
    """
    # N^2 = -g / rho_0 * drho/dz evaluated on W points between adjacent T cells
    fxa = -settings.grav / settings.rho_0 / vs.dzw[npx.newaxis, npx.newaxis, :-1] * vs.maskW[:, :, :-1]
    vs.Nsqr = update(
        vs.Nsqr,
        at[:, :, :-1, n],
        fxa * (density.get_rho(state, salt[:, :, 1:], temp[:, :, 1:], press[:-1]) - vs.rho[:, :, :-1, n]),
    )
    # copy second-to-top value into the top level, which has no cell above
    vs.Nsqr = update(vs.Nsqr, at[:, :, -1, n], vs.Nsqr[:, :, -2, n])

    return KernelOutput(
        rho=vs.rho, prho=vs.prho, Hd=vs.Hd, int_drhodT=vs.int_drhodT, int_drhodS=vs.int_drhodS, Nsqr=vs.Nsqr
    )
@veros_kernel
def advect_temp_salt_enthalpy(state):
    """
    integrate temperature and salinity and diagnose sources of dynamic enthalpy
    """
    vs = state.variables
    settings = state.settings

    # advective tendencies for the current time level
    vs.dtemp = advect_temperature(state).dtemp
    vs.dsalt = advect_salinity(state).dsalt

    if settings.enable_conserve_energy:
        """
        advection of dynamic enthalpy
        """
        if settings.enable_superbee_advection:
            flux_east, flux_north, flux_top = advection.adv_flux_superbee(state, vs.Hd[:, :, :, vs.tau])
        else:
            flux_east, flux_north, flux_top = advection.adv_flux_2nd(state, vs.Hd[:, :, :, vs.tau])

        # horizontal flux divergence of dynamic enthalpy on the inner domain
        vs.dHd = update(
            vs.dHd,
            at[2:-2, 2:-2, :, vs.tau],
            vs.maskT[2:-2, 2:-2, :]
            * (
                -(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
                - (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
            ),
        )
        # vertical flux divergence; level 0 has no flux through its lower face
        vs.dHd = update_add(vs.dHd, at[:, :, 0, vs.tau], -1 * vs.maskT[:, :, 0] * flux_top[:, :, 0] / vs.dzt[0])
        vs.dHd = update_add(
            vs.dHd,
            at[:, :, 1:, vs.tau],
            -1 * vs.maskT[:, :, 1:] * (flux_top[:, :, 1:] - flux_top[:, :, :-1]) / vs.dzt[npx.newaxis, npx.newaxis, 1:],
        )

        """
        changes in dyn. Enthalpy due to advection
        """
        # residual between enthalpy change implied by tracer advection and
        # the directly advected enthalpy tendency
        diss = allocate(state.dimensions, ("xt", "yt", "zt"))
        diss = update(
            diss,
            at[2:-2, 2:-2, :],
            settings.grav
            / settings.rho_0
            * (
                -vs.int_drhodT[2:-2, 2:-2, :, vs.tau] * vs.dtemp[2:-2, 2:-2, :, vs.tau]
                - vs.int_drhodS[2:-2, 2:-2, :, vs.tau] * vs.dsalt[2:-2, 2:-2, :, vs.tau]
            )
            - vs.dHd[2:-2, 2:-2, :, vs.tau],
        )

        """
        contribution by vertical advection is - g rho w / rho0, substract this also
        """
        # interface-averaged density times w, distributed to the cells above
        # and below each W point (factor 0.25 = 0.5 average * 0.5 split)
        diss = update_add(
            diss,
            at[:, :, :-1],
            -0.25
            * settings.grav
            / settings.rho_0
            * vs.w[:, :, :-1, vs.tau]
            * (vs.rho[:, :, :-1, vs.tau] + vs.rho[:, :, 1:, vs.tau])
            * vs.dzw[npx.newaxis, npx.newaxis, :-1]
            / vs.dzt[npx.newaxis, npx.newaxis, :-1],
        )
        diss = update_add(
            diss,
            at[:, :, 1:],
            -0.25
            * settings.grav
            / settings.rho_0
            * vs.w[:, :, :-1, vs.tau]
            * (vs.rho[:, :, 1:, vs.tau] + vs.rho[:, :, :-1, vs.tau])
            * vs.dzw[npx.newaxis, npx.newaxis, :-1]
            / vs.dzt[npx.newaxis, npx.newaxis, 1:],
        )

    # NOTE: `diss` is only defined when enable_conserve_energy is set, which is
    # guaranteed by the combined condition below
    if settings.enable_conserve_energy and settings.enable_tke:
        """
        dissipation by advection interpolated on W-grid
        """
        vs.P_diss_adv = diffusion.dissipation_on_wgrid(state, diss, vs.kbot)

        """
        distribute P_diss_adv over domain, prevent draining of TKE
        """
        # fxa: volume integral of P_diss_adv; fxb: volume where TKE is positive
        # (plus the top layer); their ratio redistributes the dissipation
        fxa = npx.sum(
            vs.area_t[2:-2, 2:-2, npx.newaxis]
            * vs.P_diss_adv[2:-2, 2:-2, :-1]
            * vs.dzw[npx.newaxis, npx.newaxis, :-1]
            * vs.maskW[2:-2, 2:-2, :-1]
        ) + npx.sum(0.5 * vs.area_t[2:-2, 2:-2] * vs.P_diss_adv[2:-2, 2:-2, -1] * vs.dzw[-1] * vs.maskW[2:-2, 2:-2, -1])

        tke_mask = vs.tke[2:-2, 2:-2, :-1, vs.tau] > 0.0
        fxb = npx.sum(
            vs.area_t[2:-2, 2:-2, npx.newaxis]
            * vs.dzw[npx.newaxis, npx.newaxis, :-1]
            * vs.maskW[2:-2, 2:-2, :-1]
            * tke_mask
        ) + npx.sum(0.5 * vs.area_t[2:-2, 2:-2] * vs.dzw[-1] * vs.maskW[2:-2, 2:-2, -1])

        # reduce over all processes before forming the global ratio
        fxa = global_sum(fxa)
        fxb = global_sum(fxb)

        vs.P_diss_adv = update(vs.P_diss_adv, at[2:-2, 2:-2, :-1], fxa / fxb * tke_mask)
        vs.P_diss_adv = update(vs.P_diss_adv, at[2:-2, 2:-2, -1], fxa / fxb)

    """
    Adam Bashforth time stepping for advection
    """
    vs.temp = update(
        vs.temp,
        at[:, :, :, vs.taup1],
        vs.temp[:, :, :, vs.tau]
        + settings.dt_tracer
        * ((1.5 + settings.AB_eps) * vs.dtemp[:, :, :, vs.tau] - (0.5 + settings.AB_eps) * vs.dtemp[:, :, :, vs.taum1])
        * vs.maskT,
    )
    vs.salt = update(
        vs.salt,
        at[:, :, :, vs.taup1],
        vs.salt[:, :, :, vs.tau]
        + settings.dt_tracer
        * ((1.5 + settings.AB_eps) * vs.dsalt[:, :, :, vs.tau] - (0.5 + settings.AB_eps) * vs.dsalt[:, :, :, vs.taum1])
        * vs.maskT,
    )

    return KernelOutput(
        temp=vs.temp, salt=vs.salt, dtemp=vs.dtemp, dsalt=vs.dsalt, dHd=vs.dHd, P_diss_adv=vs.P_diss_adv
    )
@veros_kernel
def vertmix_tempsalt(state):
    """
    vertical mixing of temperature and salinity

    Solves implicit vertical diffusion with kappaH for both tracers at
    time level taup1 and diagnoses the resulting mixing tendencies.
    """
    vs = state.variables
    settings = state.settings

    # remember pre-mixing fields so the mixing tendency can be diagnosed below
    vs.dtemp_vmix = update(vs.dtemp_vmix, at[...], vs.temp[:, :, :, vs.taup1])
    vs.dsalt_vmix = update(vs.dsalt_vmix, at[...], vs.salt[:, :, :, vs.taup1])

    # tridiagonal coefficients on the inner (halo-free) domain
    a_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
    b_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
    c_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
    d_tri = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]
    delta = allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2]

    _, water_mask, edge_mask = utilities.create_water_masks(vs.kbot[2:-2, 2:-2], settings.nz)

    # delta holds dt * kappaH / dzw at cell interfaces; zero at the surface
    delta = update(
        delta, at[:, :, :-1], settings.dt_tracer / vs.dzw[npx.newaxis, npx.newaxis, :-1] * vs.kappaH[2:-2, 2:-2, :-1]
    )
    delta = update(delta, at[:, :, -1], 0.0)

    a_tri = update(a_tri, at[:, :, 1:], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, 1:])
    b_tri = update(b_tri, at[:, :, 1:], 1 + (delta[:, :, 1:] + delta[:, :, :-1]) / vs.dzt[npx.newaxis, npx.newaxis, 1:])
    # modified diagonal used at the bottom-edge cells of each water column
    b_tri_edge = 1 + delta / vs.dzt[npx.newaxis, npx.newaxis, :]
    c_tri = update(c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzt[npx.newaxis, npx.newaxis, :-1])

    # temperature: RHS plus surface heat forcing, then implicit solve
    d_tri = vs.temp[2:-2, 2:-2, :, vs.taup1]
    d_tri = update_add(d_tri, at[:, :, -1], settings.dt_tracer * vs.forc_temp_surface[2:-2, 2:-2] / vs.dzt[-1])
    sol = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
    vs.temp = update(vs.temp, at[2:-2, 2:-2, :, vs.taup1], npx.where(water_mask, sol, vs.temp[2:-2, 2:-2, :, vs.taup1]))

    # salinity: same coefficients, surface salt forcing
    d_tri = vs.salt[2:-2, 2:-2, :, vs.taup1]
    d_tri = update_add(d_tri, at[:, :, -1], settings.dt_tracer * vs.forc_salt_surface[2:-2, 2:-2] / vs.dzt[-1])
    sol = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
    vs.salt = update(vs.salt, at[2:-2, 2:-2, :, vs.taup1], npx.where(water_mask, sol, vs.salt[2:-2, 2:-2, :, vs.taup1]))

    # mixing tendencies = (after - before) / dt
    vs.dtemp_vmix = (vs.temp[:, :, :, vs.taup1] - vs.dtemp_vmix) / settings.dt_tracer
    vs.dsalt_vmix = (vs.salt[:, :, :, vs.taup1] - vs.dsalt_vmix) / settings.dt_tracer

    """
    boundary exchange
    """
    vs.temp = update(
        vs.temp, at[..., vs.taup1], utilities.enforce_boundaries(vs.temp[..., vs.taup1], settings.enable_cyclic_x)
    )
    vs.salt = update(
        vs.salt, at[..., vs.taup1], utilities.enforce_boundaries(vs.salt[..., vs.taup1], settings.enable_cyclic_x)
    )

    return KernelOutput(dtemp_vmix=vs.dtemp_vmix, temp=vs.temp, dsalt_vmix=vs.dsalt_vmix, salt=vs.salt)
@veros_kernel
def surf_densityf(state):
    """
    Diagnose the surface density flux from surface heat and salt forcing.
    """
    vs = state.variables

    # surface-layer fields at the new time level
    surf_salt = vs.salt[:, :, -1, vs.taup1]
    surf_temp = vs.temp[:, :, -1, vs.taup1]
    surf_press = npx.abs(vs.zt[-1])

    # chain rule: density flux = drho/dT * temp forcing + drho/dS * salt forcing
    temp_part = density.get_drhodT(state, surf_salt, surf_temp, surf_press) * vs.forc_temp_surface
    salt_part = density.get_drhodS(state, surf_salt, surf_temp, surf_press) * vs.forc_salt_surface

    vs.forc_rho_surface = vs.maskT[:, :, -1] * (temp_part + salt_part)
    return KernelOutput(forc_rho_surface=vs.forc_rho_surface)
@veros_kernel
def diag_P_diss_v(state):
    """
    Diagnose dissipation of potential energy by vertical mixing (P_diss_v) and,
    when energy conservation is enabled, the non-linear equation-of-state
    contribution (P_diss_nonlin).
    """
    vs = state.variables
    settings = state.settings

    vs.P_diss_v = update(vs.P_diss_v, at[...], 0.0)
    aloc = allocate(state.dimensions, ("xt", "yt", "zt"))

    if settings.enable_conserve_energy:
        """
        diagnose dissipation of dynamic enthalpy by vertical mixing
        """
        # vertical gradient of int_drhodT on W points
        fxa = (-vs.int_drhodT[2:-2, 2:-2, 1:, vs.taup1] + vs.int_drhodT[2:-2, 2:-2, :-1, vs.taup1]) / vs.dzw[
            npx.newaxis, npx.newaxis, :-1
        ]
        vs.P_diss_v = update_add(
            vs.P_diss_v,
            at[2:-2, 2:-2, :-1],
            -settings.grav
            / settings.rho_0
            * fxa
            * vs.kappaH[2:-2, 2:-2, :-1]
            * (vs.temp[2:-2, 2:-2, 1:, vs.taup1] - vs.temp[2:-2, 2:-2, :-1, vs.taup1])
            / vs.dzw[npx.newaxis, npx.newaxis, :-1]
            * vs.maskW[2:-2, 2:-2, :-1],
        )
        # same contribution from the salinity gradient
        fxa = (-vs.int_drhodS[2:-2, 2:-2, 1:, vs.taup1] + vs.int_drhodS[2:-2, 2:-2, :-1, vs.taup1]) / vs.dzw[
            npx.newaxis, npx.newaxis, :-1
        ]
        vs.P_diss_v = update_add(
            vs.P_diss_v,
            at[2:-2, 2:-2, :-1],
            -settings.grav
            / settings.rho_0
            * fxa
            * vs.kappaH[2:-2, 2:-2, :-1]
            * (vs.salt[2:-2, 2:-2, 1:, vs.taup1] - vs.salt[2:-2, 2:-2, :-1, vs.taup1])
            / vs.dzw[npx.newaxis, npx.newaxis, :-1]
            * vs.maskW[2:-2, 2:-2, :-1],
        )

        # surface-layer contributions from heat and salt forcing
        fxa = 2 * vs.int_drhodT[2:-2, 2:-2, -1, vs.taup1] / vs.dzw[-1]
        vs.P_diss_v = update_add(
            vs.P_diss_v,
            at[2:-2, 2:-2, -1],
            -settings.grav / settings.rho_0 * fxa * vs.forc_temp_surface[2:-2, 2:-2] * vs.maskW[2:-2, 2:-2, -1],
        )
        fxa = 2 * vs.int_drhodS[2:-2, 2:-2, -1, vs.taup1] / vs.dzw[-1]
        vs.P_diss_v = update_add(
            vs.P_diss_v,
            at[2:-2, 2:-2, -1],
            -settings.grav / settings.rho_0 * fxa * vs.forc_salt_surface[2:-2, 2:-2] * vs.maskW[2:-2, 2:-2, -1],
        )

    if settings.enable_conserve_energy:
        """
        determine effect due to nonlinear equation of state
        """
        # linear-EOS expectation is kappaH * N^2; the excess is the non-linear part
        aloc = update(aloc, at[:, :, :-1], vs.kappaH[:, :, :-1] * vs.Nsqr[:, :, :-1, vs.taup1])
        vs.P_diss_nonlin = update(vs.P_diss_nonlin, at[:, :, :-1], vs.P_diss_v[:, :, :-1] - aloc[:, :, :-1])
        vs.P_diss_v = update(vs.P_diss_v, at[:, :, :-1], aloc[:, :, :-1])
    else:
        """
        diagnose N^2 vs. kappaH, i.e. exchange of pot. energy with TKE
        """
        vs.P_diss_v = update(vs.P_diss_v, at[:, :, :-1], vs.kappaH[:, :, :-1] * vs.Nsqr[:, :, :-1, vs.taup1])
        vs.P_diss_v = update(
            vs.P_diss_v, at[:, :, -1], -vs.forc_rho_surface * vs.maskT[:, :, -1] * settings.grav / settings.rho_0
        )

    return KernelOutput(P_diss_v=vs.P_diss_v, P_diss_nonlin=vs.P_diss_nonlin)
@veros_routine
def thermodynamics(state):
    """
    integrate temperature and salinity and diagnose sources of dynamic enthalpy
    """

    """
    Advection tendencies for temperature, salinity and dynamic enthalpy
    """
    vs = state.variables
    settings = state.settings

    vs.update(advect_temp_salt_enthalpy(state))

    """
    horizontal diffusion
    """
    with state.timers["isoneutral"]:
        if settings.enable_hor_diffusion:
            vs.update(diffusion.tempsalt_diffusion(state))

        if settings.enable_biharmonic_mixing:
            vs.update(diffusion.tempsalt_biharmonic(state))

        """
        sources like restoring zones, etc
        """
        if settings.enable_tempsalt_sources:
            vs.update(diffusion.tempsalt_sources(state))

        """
        isopycnal diffusion
        """
        if settings.enable_neutral_diffusion:
            # reset the isoneutral diagnostics before accumulation
            vs.P_diss_iso = update(vs.P_diss_iso, at[...], 0.0)
            vs.dtemp_iso = update(vs.dtemp_iso, at[...], 0.0)
            vs.dsalt_iso = update(vs.dsalt_iso, at[...], 0.0)

            vs.update(isoneutral.isoneutral_diffusion_pre(state))
            vs.update(isoneutral.isoneutral_diffusion(state, tr=vs.temp, istemp=True))
            vs.update(isoneutral.isoneutral_diffusion(state, tr=vs.salt, istemp=False))

            if settings.enable_skew_diffusion:
                vs.P_diss_skew = update(vs.P_diss_skew, at[...], 0.0)
                vs.update(isoneutral.isoneutral_skew_diffusion(state, tr=vs.temp, istemp=True))
                vs.update(isoneutral.isoneutral_skew_diffusion(state, tr=vs.salt, istemp=False))

    with state.timers["vmix"]:
        vs.update(vertmix_tempsalt(state))

    with state.timers["eq_of_state"]:
        # evaluate density / N^2 / enthalpy at the new time level
        vs.update(calc_eq_of_state(state, vs.taup1))

    """
    surface density flux
    """
    vs.update(surf_densityf(state))

    with state.timers["vmix"]:
        vs.update(diag_P_diss_v(state))
from veros import veros_kernel, veros_routine, KernelOutput
from veros.variables import allocate
from veros.core import advection, utilities
from veros.core.operators import update, update_add, at, for_loop, numpy as npx
@veros_routine
def set_tke_diffusivities(state):
    """
    Set vertical viscosity (kappaM) and diffusivity (kappaH), either from the
    TKE closure or from constant background values.
    """
    vs = state.variables
    settings = state.settings

    if not settings.enable_tke:
        # constant background mixing; use a large diffusivity (1.0) wherever
        # stratification is unstable (N^2 < 0)
        vs.kappaM = update(vs.kappaM, at[...], settings.kappaM_0)
        vs.kappaH = npx.where(vs.Nsqr[..., vs.tau] < 0.0, 1.0, settings.kappaH_0)
        return

    vs.update(set_tke_diffusivities_kernel(state))
@veros_kernel
def set_tke_diffusivities_kernel(state):
    """
    set vertical diffusivities based on TKE model
    """
    vs = state.variables
    settings = state.settings

    Rinumber = allocate(state.dimensions, ("xt", "yt", "zt"))

    # sqrt of TKE, clipped at zero to guard against negative values
    vs.sqrttke = npx.sqrt(npx.maximum(0.0, vs.tke[:, :, :, vs.tau]))

    """
    calculate buoyancy length scale
    """
    # N^2 is floored at 1e-12 to avoid division by zero
    vs.mxl = npx.sqrt(2) * vs.sqrttke / npx.sqrt(npx.maximum(1e-12, vs.Nsqr[:, :, :, vs.tau])) * vs.maskW

    """
    apply limits for mixing length
    """
    if settings.tke_mxl_choice == 1:
        """
        bounded by the distance to surface/bottom
        """
        vs.mxl = npx.minimum(
            npx.minimum(vs.mxl, -vs.zw[npx.newaxis, npx.newaxis, :] + vs.dzw[npx.newaxis, npx.newaxis, :] * 0.5),
            vs.ht[:, :, npx.newaxis] + vs.zw[npx.newaxis, npx.newaxis, :],
        )
        vs.mxl = npx.maximum(vs.mxl, settings.mxl_min)
    elif settings.tke_mxl_choice == 2:
        """
        bound length scale as in mitgcm/OPA code
        """
        nz = state.dimensions["zt"]

        # downward sweep: enforce mxl[k] <= mxl[k+1] + dzt[k+1]
        def backwards_pass(kinv, mxl):
            k = nz - kinv - 1
            return update(mxl, at[:, :, k], npx.minimum(mxl[:, :, k], mxl[:, :, k + 1] + vs.dzt[k + 1]))

        vs.mxl = for_loop(1, nz, backwards_pass, vs.mxl)

        vs.mxl = update(vs.mxl, at[:, :, -1], npx.minimum(vs.mxl[:, :, -1], settings.mxl_min + vs.dzt[-1]))

        # upward sweep: enforce mxl[k] <= mxl[k-1] + dzt[k]
        def forwards_pass(k, mxl):
            return update(mxl, at[:, :, k], npx.minimum(mxl[:, :, k], mxl[:, :, k - 1] + vs.dzt[k]))

        vs.mxl = for_loop(1, nz, forwards_pass, vs.mxl)
        vs.mxl = npx.maximum(vs.mxl, settings.mxl_min)
    else:
        raise ValueError("unknown mixing length choice in tke_mxl_choice")

    """
    calculate viscosity and diffusivity based on Prandtl number
    """
    vs.K_diss_v = utilities.enforce_boundaries(vs.K_diss_v, settings.enable_cyclic_x)
    vs.kappaM = update(vs.kappaM, at[...], npx.minimum(settings.kappaM_max, settings.c_k * vs.mxl * vs.sqrttke))
    Rinumber = update(
        Rinumber, at[...], vs.Nsqr[:, :, :, vs.tau] / npx.maximum(vs.K_diss_v / npx.maximum(1e-12, vs.kappaM), 1e-12)
    )

    if settings.enable_idemix:
        # with IDEMIX, also bound Ri by the internal-wave energy expression
        Rinumber = update(
            Rinumber,
            at[...],
            npx.minimum(
                Rinumber,
                vs.kappaM * vs.Nsqr[:, :, :, vs.tau] / npx.maximum(1e-12, vs.alpha_c * vs.E_iw[:, :, :, vs.tau] ** 2),
            ),
        )

    if settings.enable_Prandtl_tke:
        # Prandtl number depends on the Richardson number, clipped to [1, 10]
        vs.Prandtlnumber = npx.maximum(1.0, npx.minimum(10, 6.6 * Rinumber))
    else:
        vs.Prandtlnumber = update(vs.Prandtlnumber, at[...], settings.Prandtl_tke0)

    vs.kappaH = npx.maximum(settings.kappaH_min, vs.kappaM / vs.Prandtlnumber)

    if settings.enable_kappaH_profile:
        # Correct diffusivity according to
        # Bryan, K., and L. J. Lewis, 1979:
        # A water mass model of the world ocean. J. Geophys. Res., 84, 2503-2517.
        # It mainly modifies kappaH within 20S - 20N deg. belt
        vs.kappaH = npx.maximum(
            vs.kappaH,
            (0.8 + 1.05 / settings.pi * npx.arctan((-vs.zw[npx.newaxis, npx.newaxis, :] - 2500.0) / 222.2)) * 1e-4,
        )

    vs.kappaM = npx.maximum(settings.kappaM_min, vs.kappaM)

    return KernelOutput(
        sqrttke=vs.sqrttke,
        mxl=vs.mxl,
        kappaM=vs.kappaM,
        kappaH=vs.kappaH,
        Prandtlnumber=vs.Prandtlnumber,
        K_diss_v=vs.K_diss_v,
    )
@veros_routine
def integrate_tke(state):
    """
    Advance the TKE equation by one time step and store the kernel results.
    """
    state.variables.update(integrate_tke_kernel(state))
@veros_kernel
def integrate_tke_kernel(state):
    """
    integrate Tke equation on W grid with surface flux boundary condition
    """
    vs = state.variables
    settings = state.settings

    # extra outputs (dtke) are only returned when TKE advection is enabled
    conditional_outputs = {}

    flux_east = allocate(state.dimensions, ("xt", "yt", "zt"))
    flux_north = allocate(state.dimensions, ("xt", "yt", "zt"))
    flux_top = allocate(state.dimensions, ("xt", "yt", "zt"))

    dt_tke = settings.dt_mom  # use momentum time step to prevent spurious oscillations

    """
    Sources and sinks by vertical friction, vertical mixing, and non-conservative advection
    """
    forc = vs.K_diss_v - vs.P_diss_v - vs.P_diss_adv

    """
    store transfer due to vertical mixing from dyn. enthalpy by non-linear eq.of
    state either to TKE or to heat
    """
    if not settings.enable_store_cabbeling_heat:
        forc = forc - vs.P_diss_nonlin

    """
    transfer part of dissipation of EKE to TKE
    """
    if settings.enable_eke:
        forc = forc + vs.eke_diss_tke

    if settings.enable_idemix:
        """
        transfer dissipation of internal waves to TKE
        """
        forc = forc + vs.iw_diss

        """
        store bottom friction either in TKE or internal waves
        """
        if settings.enable_store_bottom_friction_tke:
            forc = forc + vs.K_diss_bot
    else:  # short-cut without idemix
        if settings.enable_eke:
            forc = forc + vs.eke_diss_iw
        else:  # and without EKE model
            if settings.enable_store_cabbeling_heat:
                forc = forc + vs.K_diss_gm + vs.K_diss_h - vs.P_diss_skew - vs.P_diss_hmix - vs.P_diss_iso
            else:
                forc = forc + vs.K_diss_gm + vs.K_diss_h - vs.P_diss_skew

        forc = forc + vs.K_diss_bot

    """
    vertical mixing and dissipation of TKE
    """
    _, water_mask, edge_mask = utilities.create_water_masks(vs.kbot[2:-2, 2:-2], settings.nz)

    # tridiagonal coefficients on the inner (halo-free) domain
    a_tri, b_tri, c_tri, d_tri, delta = (
        allocate(state.dimensions, ("xt", "yt", "zt"))[2:-2, 2:-2, :] for _ in range(5)
    )

    # delta: dt * effective diffusivity / dz at interfaces (kappaM averaged
    # between adjacent levels, scaled by alpha_tke)
    delta = update(
        delta,
        at[:, :, :-1],
        dt_tke
        / vs.dzt[npx.newaxis, npx.newaxis, 1:]
        * settings.alpha_tke
        * 0.5
        * (vs.kappaM[2:-2, 2:-2, :-1] + vs.kappaM[2:-2, 2:-2, 1:]),
    )

    a_tri = update(a_tri, at[:, :, 1:-1], -delta[:, :, :-2] / vs.dzw[npx.newaxis, npx.newaxis, 1:-1])
    # top level uses half a grid box (surface W point)
    a_tri = update(a_tri, at[:, :, -1], -delta[:, :, -2] / (0.5 * vs.dzw[-1]))

    # diagonal includes the linearized dissipation term c_eps * sqrt(TKE) / mxl
    b_tri = update(
        b_tri,
        at[:, :, 1:-1],
        1
        + (delta[:, :, 1:-1] + delta[:, :, :-2]) / vs.dzw[npx.newaxis, npx.newaxis, 1:-1]
        + dt_tke * settings.c_eps * vs.sqrttke[2:-2, 2:-2, 1:-1] / vs.mxl[2:-2, 2:-2, 1:-1],
    )
    b_tri = update(
        b_tri,
        at[:, :, -1],
        1
        + delta[:, :, -2] / (0.5 * vs.dzw[-1])
        + dt_tke * settings.c_eps / vs.mxl[2:-2, 2:-2, -1] * vs.sqrttke[2:-2, 2:-2, -1],
    )
    # modified diagonal used at the bottom-edge cells of each water column
    b_tri_edge = (
        1
        + delta / vs.dzw[npx.newaxis, npx.newaxis, :]
        + dt_tke * settings.c_eps / vs.mxl[2:-2, 2:-2, :] * vs.sqrttke[2:-2, 2:-2, :]
    )

    c_tri = update(c_tri, at[:, :, :-1], -delta[:, :, :-1] / vs.dzw[npx.newaxis, npx.newaxis, :-1])

    # RHS: old TKE plus forcing; surface flux enters the top half-box
    d_tri = update(d_tri, at[...], vs.tke[2:-2, 2:-2, :, vs.tau] + dt_tke * forc[2:-2, 2:-2, :])
    d_tri = update_add(d_tri, at[:, :, -1], dt_tke * vs.forc_tke_surface[2:-2, 2:-2] / (0.5 * vs.dzw[-1]))

    sol = utilities.solve_implicit(a_tri, b_tri, c_tri, d_tri, water_mask, b_edge=b_tri_edge, edge_mask=edge_mask)
    vs.tke = update(vs.tke, at[2:-2, 2:-2, :, vs.taup1], npx.where(water_mask, sol, vs.tke[2:-2, 2:-2, :, vs.taup1]))

    """
    store tke dissipation for diagnostics
    """
    vs.tke_diss = settings.c_eps / vs.mxl * vs.sqrttke * vs.tke[:, :, :, vs.taup1]

    """
    Add TKE if surface density flux drains TKE in uppermost box
    """
    # record the correction applied where TKE would otherwise go negative
    mask = vs.tke[2:-2, 2:-2, -1, vs.taup1] < 0.0
    vs.tke_surf_corr = update(
        vs.tke_surf_corr,
        at[2:-2, 2:-2],
        npx.where(mask, -vs.tke[2:-2, 2:-2, -1, vs.taup1] * 0.5 * vs.dzw[-1] / dt_tke, 0.0),
    )
    vs.tke = update(vs.tke, at[2:-2, 2:-2, -1, vs.taup1], npx.maximum(0.0, vs.tke[2:-2, 2:-2, -1, vs.taup1]))

    if settings.enable_tke_hor_diffusion:
        """
        add tendency due to lateral diffusion
        """
        flux_east = update(
            flux_east,
            at[:-1, :, :],
            settings.K_h_tke
            * (vs.tke[1:, :, :, vs.tau] - vs.tke[:-1, :, :, vs.tau])
            / (vs.cost[npx.newaxis, :, npx.newaxis] * vs.dxu[:-1, npx.newaxis, npx.newaxis])
            * vs.maskU[:-1, :, :],
        )
        flux_north = update(
            flux_north,
            at[:, :-1, :],
            settings.K_h_tke
            * (vs.tke[:, 1:, :, vs.tau] - vs.tke[:, :-1, :, vs.tau])
            / vs.dyu[npx.newaxis, :-1, npx.newaxis]
            * vs.maskV[:, :-1, :]
            * vs.cosu[npx.newaxis, :-1, npx.newaxis],
        )
        flux_north = update(flux_north, at[:, -1, :], 0.0)

        vs.tke = update_add(
            vs.tke,
            at[2:-2, 2:-2, :, vs.taup1],
            dt_tke
            * vs.maskW[2:-2, 2:-2, :]
            * (
                (flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
                + (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
            ),
        )

    """
    add tendency due to advection
    """
    if settings.enable_tke_superbee_advection:
        flux_east, flux_north, flux_top = advection.adv_flux_superbee_wgrid(state, vs.tke[:, :, :, vs.tau])

    if settings.enable_tke_upwind_advection:
        flux_east, flux_north, flux_top = advection.adv_flux_upwind_wgrid(state, vs.tke[:, :, :, vs.tau])

    if settings.enable_tke_superbee_advection or settings.enable_tke_upwind_advection:
        # horizontal flux divergence on the W grid
        vs.dtke = update(
            vs.dtke,
            at[2:-2, 2:-2, :, vs.tau],
            vs.maskW[2:-2, 2:-2, :]
            * (
                -(flux_east[2:-2, 2:-2, :] - flux_east[1:-3, 2:-2, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
                - (flux_north[2:-2, 2:-2, :] - flux_north[2:-2, 1:-3, :])
                / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dyt[npx.newaxis, 2:-2, npx.newaxis])
            ),
        )
        # vertical flux divergence; top level uses half a grid box
        vs.dtke = update_add(vs.dtke, at[:, :, 0, vs.tau], -flux_top[:, :, 0] / vs.dzw[0])
        vs.dtke = update_add(
            vs.dtke, at[:, :, 1:-1, vs.tau], -(flux_top[:, :, 1:-1] - flux_top[:, :, :-2]) / vs.dzw[1:-1]
        )
        vs.dtke = update_add(
            vs.dtke, at[:, :, -1, vs.tau], -(flux_top[:, :, -1] - flux_top[:, :, -2]) / (0.5 * vs.dzw[-1])
        )

        """
        Adam Bashforth time stepping
        """
        vs.tke = update_add(
            vs.tke,
            at[:, :, :, vs.taup1],
            settings.dt_tracer
            * (
                (1.5 + settings.AB_eps) * vs.dtke[:, :, :, vs.tau]
                - (0.5 + settings.AB_eps) * vs.dtke[:, :, :, vs.taum1]
            ),
        )

        conditional_outputs.update(dtke=vs.dtke)

    return KernelOutput(tke=vs.tke, tke_surf_corr=vs.tke_surf_corr, tke_diss=vs.tke_diss, **conditional_outputs)
from veros.core.operators import numpy as npx
from veros import veros_kernel
from veros.core.operators import update, at, solve_tridiagonal
@veros_kernel(static_args=("enable_cyclic_x", "local"))
def enforce_boundaries(arr, enable_cyclic_x, local=False):
    """
    Apply horizontal boundary conditions to ``arr``.

    On a single process (or when ``local`` is set, or when the current context
    is not distribution-safe) the 2-cell cyclic overlap is copied explicitly;
    otherwise the halos are exchanged between processes.
    """
    # deferred imports to avoid import cycles with the veros package
    from veros import runtime_state as rst
    from veros.routines import CURRENT_CONTEXT

    if rst.proc_num == 1 or not CURRENT_CONTEXT.is_dist_safe or local:
        if enable_cyclic_x:
            # copy the 2-cell-wide overlap to make the x-axis periodic
            arr = update(arr, at[-2:, ...], arr[2:4, ...])
            arr = update(arr, at[:2, ...], arr[-4:-2, ...])
        return arr

    from veros.distributed import exchange_overlap

    arr = exchange_overlap(arr, ["xt", "yt"], cyclic=enable_cyclic_x)
    return arr
@veros_kernel
def pad_z_edges(array):
    """
    Pads the z-axis of an array by repeating its edge values
    """
    ndim = array.ndim
    if ndim == 1:
        return npx.pad(array, 1, mode="edge")
    if ndim >= 3:
        # pad only the third (z) axis; leading axes stay untouched
        return npx.pad(array, ((0, 0), (0, 0), (1, 1)), mode="edge")
    raise ValueError("Array to pad needs to have 1 or at least 3 dimensions")
# NOTE: static_args must be a tuple of names; the previous value ("nz") was a
# plain string (which iterates as "n", "z"), not a one-element tuple.
@veros_kernel(static_args=("nz",))
def create_water_masks(ks, nz):
    """
    Build land / water / bottom-edge masks from the bottom level index field.

    Arguments:
        ks: 2D field of bottom cell indices; values < 1 mark land columns
            (inferred from the ``ks - 1 >= 0`` test below).
        nz: number of vertical levels (static argument).

    Returns:
        Tuple ``(land_mask, water_mask, edge_mask)`` where ``land_mask`` is 2D
        and the other two are 3D boolean masks over (x, y, z).
    """
    # shift to the 0-based index of the deepest water cell; negative -> land
    ks = ks - 1
    land_mask = ks >= 0
    # water: every level at or above the bottom cell of a water column
    water_mask = npx.logical_and(
        land_mask[:, :, npx.newaxis], npx.arange(nz)[npx.newaxis, npx.newaxis, :] >= ks[:, :, npx.newaxis]
    )
    # edge: exactly the bottom cell of each water column
    edge_mask = npx.logical_and(
        land_mask[:, :, npx.newaxis], npx.arange(nz)[npx.newaxis, npx.newaxis, :] == ks[:, :, npx.newaxis]
    )
    return land_mask, water_mask, edge_mask
@veros_kernel
def solve_implicit(a, b, c, d, water_mask, edge_mask, b_edge=None, d_edge=None):
    """
    Solve the tridiagonal system (a, b, c, d), substituting the edge
    coefficients at bottom-edge cells when they are supplied.
    """
    b_eff = b if b_edge is None else npx.where(edge_mask, b_edge, b)
    d_eff = d if d_edge is None else npx.where(edge_mask, d_edge, d)
    return solve_tridiagonal(a, b_eff, c, d_eff, water_mask, edge_mask)
from veros.diagnostics.api import create_default_diagnostics, initialize, diagnose, output # noqa: F401
from veros import logger, time
def create_default_diagnostics(state):
    """Instantiate the standard set of diagnostics, keyed by their names."""
    # do not import these at module level to make sure core import is deferred
    from veros.diagnostics.averages import Averages
    from veros.diagnostics.cfl_monitor import CFLMonitor
    from veros.diagnostics.energy import Energy
    from veros.diagnostics.overturning import Overturning
    from veros.diagnostics.snapshot import Snapshot
    from veros.diagnostics.tracer_monitor import TracerMonitor

    diagnostics = {}
    for diag_cls in (Averages, CFLMonitor, Energy, Overturning, Snapshot, TracerMonitor):
        diagnostics[diag_cls.name] = diag_cls(state)

    return diagnostics
def initialize(state):
    """Initialize every registered diagnostic and log its active frequencies."""
    for name, diagnostic in state.diagnostics.items():
        diagnostic.initialize(state)

        frequencies = (
            (diagnostic.sampling_frequency, ' Running diagnostic "{}" every {:.1f} {}'),
            (diagnostic.output_frequency, ' Writing output for diagnostic "{}" every {:.1f} {}'),
        )
        for frequency, message in frequencies:
            if frequency:
                t, unit = time.format_time(frequency)
                logger.info(message.format(name, t, unit))
def diagnose(state):
    """Trigger sampling for every diagnostic whose sampling interval has elapsed."""
    vs = state.variables
    settings = state.settings

    for diagnostic in state.diagnostics.values():
        frequency = diagnostic.sampling_frequency
        # fires once per interval: the remainder drops below dt_tracer exactly
        # when a multiple of the interval has just been passed
        if frequency and vs.time % frequency < settings.dt_tracer:
            diagnostic.diagnose(state)
def output(state):
    """Trigger output for every diagnostic whose output interval has elapsed."""
    vs = state.variables
    settings = state.settings

    for diagnostic in state.diagnostics.values():
        frequency = diagnostic.output_frequency
        # same once-per-interval trigger as in diagnose()
        if frequency and vs.time % frequency < settings.dt_tracer:
            diagnostic.output(state)
import os
import copy
from veros.diagnostics.base import VerosDiagnostic
from veros.variables import TIMESTEPS, Variable
class Averages(VerosDiagnostic):
    """Time average output diagnostic.

    All registered variables are summed up when :meth:`diagnose` is called,
    and averaged and output upon calling :meth:`output`.
    """

    name = "averages"  #:
    output_path = "{identifier}.averages.nc"  #: File to write to. May contain format strings that are replaced with Veros attributes.
    output_variables = None  #: Iterable containing all variables to be averaged. Changes have no effect after ``initialize`` has been called.
    output_frequency = None  #: Frequency (in seconds) in which output is written.
    sampling_frequency = None  #: Frequency (in seconds) in which variables are accumulated.

    def __init__(self, state):
        # number of accumulated samples; registered in var_meta so it is
        # written to restart files
        self.var_meta = {
            "average_nitts": Variable("average_nitts", None, write_to_restart=True),
        }
        self.output_variables = []

    def initialize(self, state):
        """Register all variables to be averaged"""
        for var in self.output_variables:
            # copy metadata so modifications do not leak into the global state
            var_meta = copy.copy(state.var_meta[var])
            var_meta.time_dependent = True
            var_meta.write_to_restart = True

            if self._has_timestep_dim(state, var):
                # averages collapse the time-step dimension (only one level is sampled)
                var_meta.dims = var_meta.dims[:-1]

            self.var_meta[var] = var_meta

        self.initialize_variables(state)
        self.initialize_output(state)

    @staticmethod
    def _has_timestep_dim(state, var):
        # True if the variable's last dimension is the time-step axis
        if state.var_meta[var].dims is None:
            return False

        return state.var_meta[var].dims[-1] == TIMESTEPS[0]

    def diagnose(self, state):
        """Accumulate the current model state into the running sums."""
        vs = state.variables
        avg_vs = self.variables

        avg_vs.average_nitts = avg_vs.average_nitts + 1

        for key in self.output_variables:
            var_data = getattr(avg_vs, key)
            if self._has_timestep_dim(state, key):
                # only sample the current time level
                setattr(avg_vs, key, var_data + getattr(vs, key)[..., vs.tau])
            else:
                setattr(avg_vs, key, var_data + getattr(vs, key))

    def output(self, state):
        """Write averages to netcdf file and zero array"""
        avg_vs = self.variables

        if not os.path.isfile(self.get_output_file_name(state)):
            self.initialize_output(state)

        if avg_vs.average_nitts > 0:
            # convert the accumulated sums to averages before writing
            for key in self.output_variables:
                val = getattr(avg_vs, key)
                setattr(avg_vs, key, val / avg_vs.average_nitts)

            self.write_output(state)

        # reset accumulators for the next averaging interval
        for key in self.output_variables:
            val = getattr(avg_vs, key)
            setattr(avg_vs, key, 0 * val)

        avg_vs.average_nitts = 0
import abc
import os
from veros.io_tools import netcdf as nctools
from veros.signals import do_not_disturb
from veros.state import VerosVariables
from veros import distributed, runtime_settings, time
class VerosDiagnostic(metaclass=abc.ABCMeta):
    """Base class for diagnostics. Provides an interface and wrappers for common I/O.

    Any diagnostic needs to implement the 5 interface methods and set some attributes.
    """

    name = None  #: Name that identifies the current diagnostic
    sampling_frequency = 0.0  #: Interval (in seconds) between calls to ``diagnose``
    output_frequency = 0.0  #: Interval (in seconds) between calls to ``output``
    output_path = None  #: Output file name template (may reference state attributes)
    output_variables = None  #: Iterable of variable names to be written
    var_meta = None  #: Metadata of internal variables
    extra_dimensions = None  #: Dict of extra dimensions used in var_meta

    def __init__(self, state):
        pass

    @abc.abstractmethod
    def initialize(self, state):
        """Called at the end of setup. Use this to process user settings and handle setup."""
        pass

    @abc.abstractmethod
    def diagnose(self, state):
        """Called with frequency ``sampling_frequency``."""
        pass

    @abc.abstractmethod
    def output(self, state):
        """Called with frequency ``output_frequency``."""
        pass

    def initialize_variables(self, state):
        """Allocate the internal diagnostic variables described by ``var_meta``."""
        if self.var_meta is None:
            self.variables = None
            return

        dimensions = dict(state.dimensions)
        if self.extra_dimensions is not None:
            dimensions.update(self.extra_dimensions)

        self.variables = VerosVariables(self.var_meta, dimensions)
        # we leave diagnostic variables unlocked
        self.variables.__locked__ = False

    def get_output_file_name(self, state):
        """Expand ``output_path`` using current variable and setting values."""
        statedict = dict(state.variables.items())
        statedict.update(state.settings.items())
        return self.output_path.format(**statedict)

    @do_not_disturb
    def initialize_output(self, state):
        """Create the output file and write all time-independent variables."""
        # skip when this diagnostic is inactive or has nothing to write
        inactive = not self.output_frequency and not self.sampling_frequency
        no_output = not self.output_path or not self.output_variables
        if runtime_settings.diskless_mode or inactive or no_output:
            return

        output_path = self.get_output_file_name(state)

        if os.path.isfile(output_path) and not runtime_settings.force_overwrite:
            raise IOError(
                f'output file {output_path} for diagnostic "{self.name}" exists '
                "(change output path or enable force_overwrite runtime setting)"
            )

        # possible race condition ahead!
        distributed.barrier()

        with nctools.threaded_io(output_path, "w") as outfile:
            nctools.initialize_file(state, outfile, extra_dimensions=self.extra_dimensions)

            for key in self.output_variables:
                var = self.var_meta[key]

                if key not in outfile.variables:
                    nctools.initialize_variable(state, key, var, outfile)

                if not var.time_dependent:
                    # time-independent data is written once, here
                    var_data = self.variables.get(key)
                    nctools.write_variable(state, key, var, var_data, outfile)

    @do_not_disturb
    def write_output(self, state):
        """Append the current diagnostic variables as a new time slice."""
        vs = state.variables

        if runtime_settings.diskless_mode:
            return

        with nctools.threaded_io(self.get_output_file_name(state), "r+") as outfile:
            current_days = time.convert_time(vs.time, "seconds", "days")
            nctools.advance_time(current_days, outfile)

            for key in self.output_variables:
                var = self.var_meta[key]
                var_data = self.variables.get(key)
                nctools.write_variable(state, key, var, var_data, outfile)
from veros import logger
from veros.core.operators import numpy as npx
from veros.diagnostics.base import VerosDiagnostic
from veros.distributed import global_max
class CFLMonitor(VerosDiagnostic):
    """Diagnostic monitoring the maximum CFL number of the solution to detect
    instabilities.

    Writes output to stdout (no binary output).
    """

    name = "cfl_monitor"  #:
    output_frequency = None  #: Frequency (in seconds) in which output is written.

    def initialize(self, state):
        """No state to set up for this diagnostic."""
        pass

    def diagnose(self, state):
        """Nothing to accumulate; all work happens in :meth:`output`."""
        pass

    def output(self, state):
        """
        check for CFL violation
        """
        vs = state.variables
        settings = state.settings

        # horizontal CFL: maximum of zonal and meridional Courant numbers
        # over the physical domain (ghost cells excluded via [2:-2, 2:-2])
        cfl = global_max(
            max(
                npx.max(
                    npx.abs(vs.u[2:-2, 2:-2, :, vs.tau])
                    * vs.maskU[2:-2, 2:-2, :]
                    / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
                    * settings.dt_tracer
                ),
                npx.max(
                    npx.abs(vs.v[2:-2, 2:-2, :, vs.tau])
                    * vs.maskV[2:-2, 2:-2, :]
                    / vs.dyt[npx.newaxis, 2:-2, npx.newaxis]
                    * settings.dt_tracer
                ),
            )
        )
        # vertical CFL from vertical velocity and layer thickness
        wcfl = global_max(
            npx.max(
                npx.abs(vs.w[2:-2, 2:-2, :, vs.tau])
                * vs.maskW[2:-2, 2:-2, :]
                / vs.dzt[npx.newaxis, npx.newaxis, :]
                * settings.dt_tracer
            )
        )
        # a NaN CFL means the solution has already blown up — abort the run
        if npx.isnan(cfl) or npx.isnan(wcfl):
            raise RuntimeError(f"CFL number is NaN at iteration {vs.itt}")
        logger.diagnostic(f" Maximal hor. CFL number = {cfl}")
        logger.diagnostic(f" Maximal ver. CFL number = {wcfl}")

        # repeat the check for the advection velocities on the W grid used
        # by the turbulence closures (EKE / TKE / IDEMIX)
        if settings.enable_eke or settings.enable_tke or settings.enable_idemix:
            cfl = global_max(
                max(
                    npx.max(
                        npx.abs(vs.u_wgrid[2:-2, 2:-2, :])
                        * vs.maskU[2:-2, 2:-2, :]
                        / (vs.cost[npx.newaxis, 2:-2, npx.newaxis] * vs.dxt[2:-2, npx.newaxis, npx.newaxis])
                        * settings.dt_tracer
                    ),
                    npx.max(
                        npx.abs(vs.v_wgrid[2:-2, 2:-2, :])
                        * vs.maskV[2:-2, 2:-2, :]
                        / vs.dyt[npx.newaxis, 2:-2, npx.newaxis]
                        * settings.dt_tracer
                    ),
                )
            )
            wcfl = global_max(
                npx.max(
                    npx.abs(vs.w_wgrid[2:-2, 2:-2, :])
                    * vs.maskW[2:-2, 2:-2, :]
                    / vs.dzt[npx.newaxis, npx.newaxis, :]
                    * settings.dt_tracer
                )
            )
            logger.diagnostic(f" Maximal hor. CFL number on w grid = {cfl}")
            logger.diagnostic(f" Maximal ver. CFL number on w grid = {wcfl}")
import os
from veros import veros_kernel, KernelOutput, runtime_settings
from veros.core.operators import numpy as npx, update_multiply, at
from veros.diagnostics.base import VerosDiagnostic
from veros.variables import Variable
from veros.distributed import global_sum
# Metadata for the scalar accumulator variables of the energy diagnostic.
# All entries are globally integrated scalars (dims=None) accumulated over the
# averaging interval; "nitts" counts the number of accumulated samples.
ENERGY_VARIABLES = dict(
    nitts=Variable("nitts", None, write_to_restart=True),
    # mean energy content
    k_m=Variable("Mean kinetic energy", None, "J", "Mean kinetic energy", write_to_restart=True),
    Hd_m=Variable("Mean dynamic enthalpy", None, "J", "Mean dynamic enthalpy", write_to_restart=True),
    eke_m=Variable("Meso-scale eddy energy", None, "J", "Meso-scale eddy energy", write_to_restart=True),
    iw_m=Variable("Internal wave energy", None, "J", "Internal wave energy", write_to_restart=True),
    tke_m=Variable("Turbulent kinetic energy", None, "J", "Turbulent kinetic energy", write_to_restart=True),
    # energy changes
    dE_tot_m=Variable("Change of total energy", None, "W", "Change of total energy", write_to_restart=True),
    dk_m=Variable("Change of KE", None, "W", "Change of kinetic energy", write_to_restart=True),
    dHd_m=Variable("Change of Hd", None, "W", "Change of dynamic enthalpy", write_to_restart=True),
    deke_m=Variable("Change of EKE", None, "W", "Change of meso-scale eddy energy", write_to_restart=True),
    diw_m=Variable("Change of E_iw", None, "W", "Change of internal wave energy", write_to_restart=True),
    # NOTE: long description previously read "tubulent" (typo)
    dtke_m=Variable("Change of TKE", None, "W", "Change of turbulent kinetic energy", write_to_restart=True),
    # dissipation
    ke_diss_m=Variable("Dissipation of KE", None, "W", "Dissipation of kinetic energy", write_to_restart=True),
    Hd_diss_m=Variable("Dissipation of Hd", None, "W", "Dissipation of dynamic enthalpy", write_to_restart=True),
    eke_diss_m=Variable(
        "Dissipation of EKE", None, "W", "Dissipation of meso-scale eddy energy", write_to_restart=True
    ),
    iw_diss_m=Variable("Dissipation of E_iw", None, "W", "Dissipation of internal wave energy", write_to_restart=True),
    tke_diss_m=Variable(
        "Dissipation of TKE", None, "W", "Dissipation of turbulent kinetic energy", write_to_restart=True
    ),
    adv_diss_m=Variable("Dissipation by advection", None, "W", "Dissipation by advection", write_to_restart=True),
    # external forcing
    wind_m=Variable("Wind work", None, "W", "Wind work", write_to_restart=True),
    dHd_sources_m=Variable(
        "Hd production by ext. sources",
        None,
        "W",
        "Dynamic enthalpy production through external sources",
        write_to_restart=True,
    ),
    iw_forc_m=Variable(
        "External forcing of E_iw", None, "W", "External forcing of internal wave energy", write_to_restart=True
    ),
    tke_forc_m=Variable(
        "External forcing of TKE", None, "W", "External forcing of turbulent kinetic energy", write_to_restart=True
    ),
    # exchange
    ke_hd_m=Variable(
        "Exchange KE -> Hd", None, "W", "Exchange between kinetic energy and dynamic enthalpy", write_to_restart=True
    ),
    ke_tke_m=Variable(
        "Exchange KE -> TKE by vert. friction",
        None,
        "W",
        "Exchange between kinetic energy and turbulent kinetic energy by vertical friction",
        write_to_restart=True,
    ),
    ke_iw_m=Variable(
        "Exchange KE -> IW by bottom friction",
        None,
        "W",
        "Exchange between kinetic energy and internal wave energy by bottom friction",
        write_to_restart=True,
    ),
    tke_hd_m=Variable(
        "Exchange TKE -> Hd by vertical mixing",
        None,
        "W",
        "Exchange between turbulent kinetic energy and dynamic enthalpy by vertical mixing",
        write_to_restart=True,
    ),
    ke_eke_m=Variable(
        "Exchange KE -> EKE by lateral friction",
        None,
        "W",
        "Exchange between kinetic energy and eddy kinetic energy by lateral friction",
        write_to_restart=True,
    ),
    hd_eke_m=Variable(
        "Exchange Hd -> EKE by GM and lateral mixing",
        None,
        "W",
        "Exchange between dynamic enthalpy and eddy kinetic energy by GM and lateral mixing",
        write_to_restart=True,
    ),
    eke_tke_m=Variable(
        "Exchange EKE -> TKE", None, "W", "Exchange between eddy and turbulent kinetic energy", write_to_restart=True
    ),
    eke_iw_m=Variable(
        "Exchange EKE -> IW",
        None,
        "W",
        "Exchange between eddy kinetic energy and internal wave energy",
        write_to_restart=True,
    ),
    # cabbeling
    cabb_m=Variable("Cabbeling by vertical mixing", None, "W", "Cabbeling by vertical mixing", write_to_restart=True),
    cabb_iso_m=Variable(
        "Cabbeling by isopycnal mixing", None, "W", "Cabbeling by isopycnal mixing", write_to_restart=True
    ),
)
# all energy diagnostics except the sample counter are written to output
DEFAULT_OUTPUT_VARS = [var for var in ENERGY_VARIABLES if var != "nitts"]
class Energy(VerosDiagnostic):
    """Diagnose globally averaged energy cycle. Also averages energy in time."""

    name = "energy"  #:
    output_path = "{identifier}.energy.nc"  #: File to write to. May contain format strings that are replaced with Veros attributes.
    output_frequency = None  #: Frequency (in seconds) in which output is written.
    sampling_frequency = None  #: Frequency (in seconds) in which variables are accumulated.
    var_meta = ENERGY_VARIABLES

    def __init__(self, state):
        self.output_variables = DEFAULT_OUTPUT_VARS.copy()

    def initialize(self, state):
        """Allocate accumulator variables and create the output file."""
        self.initialize_variables(state)
        self.initialize_output(state)

    def diagnose(self, state):
        """Add one sample of every energy diagnostic to the running sums."""
        energies = diagnose_kernel(state)
        # store results
        for energy, val in energies._asdict().items():
            total_val = self.variables.get(energy)
            setattr(self.variables, energy, total_val + val)
        self.variables.nitts = self.variables.nitts + 1

    def output(self, state):
        """Write time-averaged energies (scaled by rho_0) and reset accumulators."""
        if not os.path.isfile(self.get_output_file_name(state)):
            self.initialize_output(state)
        energy_vs = self.variables
        # guard against division by zero when no samples were taken
        nitts = float(energy_vs.nitts or 1)
        for key in self.output_variables:
            val = getattr(energy_vs, key)
            # kernel values are per unit rho_0; rescale and average in time
            setattr(energy_vs, key, val * state.settings.rho_0 / nitts)
        self.write_output(state)
        # reset accumulators for the next averaging interval
        for key in self.output_variables:
            setattr(energy_vs, key, 0.0)
        energy_vs.nitts = 0
@veros_kernel
def diagnose_kernel(state):
    """Compute one sample of the globally integrated energy budget.

    Returns a KernelOutput of scalar diagnostics: energy contents (k_m, Hd_m,
    eke_m, iw_m, tke_m), their time tendencies, dissipation terms, external
    forcing, and the exchange terms between the energy reservoirs. All values
    are global sums per unit rho_0 (the Energy diagnostic rescales on output).

    Fixes relative to the previous revision:
      * k_m: the meridional contribution is now 0.5 * (v**2 + v_shift**2)
        (the closing parenthesis was misplaced, double-weighting one face)
      * diw_m: uses settings.dt_tracer (dt_tracer is a setting, not a
        variable; vs.dt_tracer raised an AttributeError with IDEMIX enabled)
      * the returned hd_eke_m now includes the cabbeling correction computed
        below (it was previously discarded, leaving dead code)
    """
    vs = state.variables
    settings = state.settings

    # changes of dynamic enthalpy (split by the mixing process responsible)
    vol_t = vs.area_t[2:-2, 2:-2, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :] * vs.maskT[2:-2, 2:-2, :]
    dP_iso = global_sum(
        npx.sum(
            vol_t
            * settings.grav
            / settings.rho_0
            * (
                -vs.int_drhodT[2:-2, 2:-2, :, vs.tau] * vs.dtemp_iso[2:-2, 2:-2, :]
                - vs.int_drhodS[2:-2, 2:-2, :, vs.tau] * vs.dsalt_iso[2:-2, 2:-2, :]
            )
        )
    )
    dP_hmix = global_sum(
        npx.sum(
            vol_t
            * settings.grav
            / settings.rho_0
            * (
                -vs.int_drhodT[2:-2, 2:-2, :, vs.tau] * vs.dtemp_hmix[2:-2, 2:-2, :]
                - vs.int_drhodS[2:-2, 2:-2, :, vs.tau] * vs.dsalt_hmix[2:-2, 2:-2, :]
            )
        )
    )
    dP_vmix = global_sum(
        npx.sum(
            vol_t
            * settings.grav
            / settings.rho_0
            * (
                -vs.int_drhodT[2:-2, 2:-2, :, vs.tau] * vs.dtemp_vmix[2:-2, 2:-2, :]
                - vs.int_drhodS[2:-2, 2:-2, :, vs.tau] * vs.dsalt_vmix[2:-2, 2:-2, :]
            )
        )
    )
    dP_m = global_sum(
        npx.sum(
            vol_t
            * settings.grav
            / settings.rho_0
            * (
                -vs.int_drhodT[2:-2, 2:-2, :, vs.tau] * vs.dtemp[2:-2, 2:-2, :, vs.tau]
                - vs.int_drhodS[2:-2, 2:-2, :, vs.tau] * vs.dsalt[2:-2, 2:-2, :, vs.tau]
            )
        )
    )
    dP_m_all = dP_m + dP_vmix + dP_hmix + dP_iso

    # changes of kinetic energy
    vol_u = vs.area_u[2:-2, 2:-2, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :]
    vol_v = vs.area_v[2:-2, 2:-2, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :]
    # kinetic energy on T cells: average the squared velocities from the two
    # adjacent U / V faces (fix: both v faces inside the 0.5 * (...) average)
    k_m = global_sum(
        npx.sum(
            vol_t
            * 0.5
            * (
                0.5 * (vs.u[2:-2, 2:-2, :, vs.tau] ** 2 + vs.u[1:-3, 2:-2, :, vs.tau] ** 2)
                + 0.5 * (vs.v[2:-2, 2:-2, :, vs.tau] ** 2 + vs.v[2:-2, 1:-3, :, vs.tau] ** 2)
            )
        )
    )
    p_m = global_sum(npx.sum(vol_t * vs.Hd[2:-2, 2:-2, :, vs.tau]))
    dk_m = global_sum(
        npx.sum(
            vs.u[2:-2, 2:-2, :, vs.tau] * vs.du[2:-2, 2:-2, :, vs.tau] * vol_u
            + vs.v[2:-2, 2:-2, :, vs.tau] * vs.dv[2:-2, 2:-2, :, vs.tau] * vol_v
            + vs.u[2:-2, 2:-2, :, vs.tau] * vs.du_mix[2:-2, 2:-2, :] * vol_u
            + vs.v[2:-2, 2:-2, :, vs.tau] * vs.dv_mix[2:-2, 2:-2, :] * vol_v
        )
    )

    # K*Nsqr and KE and dyn. enthalpy dissipation
    vol_w = vs.area_t[2:-2, 2:-2, npx.newaxis] * vs.dzw[npx.newaxis, npx.newaxis, :] * vs.maskW[2:-2, 2:-2, :]
    # surface W cell has half thickness
    vol_w = update_multiply(vol_w, at[:, :, -1], 0.5)

    def mean_w(var):
        # global volume integral on the W grid (physical domain only)
        return global_sum(npx.sum(var[2:-2, 2:-2, :] * vol_w))

    mdiss_vmix = mean_w(vs.P_diss_v)
    mdiss_nonlin = mean_w(vs.P_diss_nonlin)
    mdiss_adv = mean_w(vs.P_diss_adv)
    mdiss_hmix = mean_w(vs.P_diss_hmix)
    mdiss_iso = mean_w(vs.P_diss_iso)
    mdiss_skew = mean_w(vs.P_diss_skew)
    mdiss_sources = mean_w(vs.P_diss_sources)
    mdiss_h = mean_w(vs.K_diss_h)
    mdiss_v = mean_w(vs.K_diss_v)
    mdiss_gm = mean_w(vs.K_diss_gm)
    mdiss_bot = mean_w(vs.K_diss_bot)

    # exchange KE -> dynamic enthalpy: w * dp/dz integrated over the domain
    wrhom = global_sum(
        npx.sum(
            -vs.area_t[2:-2, 2:-2, npx.newaxis]
            * vs.maskW[2:-2, 2:-2, :-1]
            * (vs.p_hydro[2:-2, 2:-2, 1:] - vs.p_hydro[2:-2, 2:-2, :-1])
            * vs.w[2:-2, 2:-2, :-1, vs.tau]
        )
    )

    # wind work
    if runtime_settings.pyom_compatibility_mode:
        # surface_tau* has different units in PyOM
        wind = global_sum(
            npx.sum(
                vs.u[2:-2, 2:-2, -1, vs.tau]
                * vs.surface_taux[2:-2, 2:-2]
                * vs.maskU[2:-2, 2:-2, -1]
                * vs.area_u[2:-2, 2:-2]
                + vs.v[2:-2, 2:-2, -1, vs.tau]
                * vs.surface_tauy[2:-2, 2:-2]
                * vs.maskV[2:-2, 2:-2, -1]
                * vs.area_v[2:-2, 2:-2]
            )
        )
    else:
        wind = global_sum(
            npx.sum(
                vs.u[2:-2, 2:-2, -1, vs.tau]
                * vs.surface_taux[2:-2, 2:-2]
                / settings.rho_0
                * vs.maskU[2:-2, 2:-2, -1]
                * vs.area_u[2:-2, 2:-2]
                + vs.v[2:-2, 2:-2, -1, vs.tau]
                * vs.surface_tauy[2:-2, 2:-2]
                / settings.rho_0
                * vs.maskV[2:-2, 2:-2, -1]
                * vs.area_v[2:-2, 2:-2]
            )
        )

    # meso-scale energy
    if settings.enable_eke:
        eke_m = mean_w(vs.eke[..., vs.tau])
        deke_m = global_sum(
            npx.sum(vol_w * (vs.eke[2:-2, 2:-2, :, vs.taup1] - vs.eke[2:-2, 2:-2, :, vs.tau]) / settings.dt_tracer)
        )
        eke_diss = mean_w(vs.eke_diss_iw)
        eke_diss_tke = mean_w(vs.eke_diss_tke)
    else:
        # without an explicit EKE budget, route the corresponding dissipation
        # terms directly into the internal wave reservoir
        eke_m = deke_m = eke_diss_tke = 0.0
        eke_diss = mdiss_gm + mdiss_h + mdiss_skew
        if not settings.enable_store_cabbeling_heat:
            eke_diss += -mdiss_hmix - mdiss_iso

    # small-scale energy
    if settings.enable_tke:
        dt_tke = settings.dt_mom  # TKE is stepped with the momentum time step
        tke_m = mean_w(vs.tke[..., vs.tau])
        dtke_m = mean_w((vs.tke[..., vs.taup1] - vs.tke[..., vs.tau]) / dt_tke)
        tke_diss = mean_w(vs.tke_diss)
        tke_forc = global_sum(
            npx.sum(
                vs.area_t[2:-2, 2:-2]
                * vs.maskW[2:-2, 2:-2, -1]
                * (vs.forc_tke_surface[2:-2, 2:-2] + vs.tke_surf_corr[2:-2, 2:-2])
            )
        )
    else:
        tke_m = dtke_m = tke_diss = tke_forc = 0.0

    # internal wave energy
    if settings.enable_idemix:
        iw_m = mean_w(vs.E_iw[..., vs.tau])
        diw_m = global_sum(
            # fix: dt_tracer is a setting (was vs.dt_tracer -> AttributeError)
            npx.sum(vol_w * (vs.E_iw[2:-2, 2:-2, :, vs.taup1] - vs.E_iw[2:-2, 2:-2, :, vs.tau]) / settings.dt_tracer)
        )
        iw_diss = mean_w(vs.iw_diss)

        # bottom forcing enters at the deepest wet W cell of each column
        k = npx.maximum(1, vs.kbot[2:-2, 2:-2]) - 1
        mask = k[:, :, npx.newaxis] == npx.arange(settings.nz)[npx.newaxis, npx.newaxis, :]
        iwforc = global_sum(
            npx.sum(
                vs.area_t[2:-2, 2:-2]
                * (
                    vs.forc_iw_surface[2:-2, 2:-2] * vs.maskW[2:-2, 2:-2, -1]
                    + npx.sum(mask * vs.forc_iw_bottom[2:-2, 2:-2, npx.newaxis] * vs.maskW[2:-2, 2:-2, :], axis=2)
                )
            )
        )
    else:
        iw_m = diw_m = iwforc = 0.0
        iw_diss = eke_diss

    if settings.enable_store_bottom_friction_tke:
        ke_tke_m = mdiss_v + mdiss_bot
        ke_iw_m = 0.0
    else:
        ke_tke_m = mdiss_v
        ke_iw_m = mdiss_bot

    hd_eke_m = -mdiss_skew
    tke_hd_m = -mdiss_vmix - mdiss_adv
    if not settings.enable_store_cabbeling_heat:
        # cabbeling heat is not stored -> attribute it to the exchange terms
        hd_eke_m = hd_eke_m - mdiss_hmix - mdiss_iso
        tke_hd_m = tke_hd_m - mdiss_nonlin

    return KernelOutput(
        k_m=k_m,
        Hd_m=p_m,
        eke_m=eke_m,
        iw_m=iw_m,
        tke_m=tke_m,
        dk_m=dk_m,
        dHd_m=dP_m_all + mdiss_sources,
        deke_m=deke_m,
        diw_m=diw_m,
        dtke_m=dtke_m,
        dE_tot_m=dk_m + dP_m_all + mdiss_sources + deke_m + diw_m + dtke_m,
        wind_m=wind,
        dHd_sources_m=mdiss_sources,
        iw_forc_m=iwforc,
        tke_forc_m=tke_forc,
        ke_diss_m=mdiss_h + mdiss_v + mdiss_gm + mdiss_bot,
        Hd_diss_m=mdiss_vmix + mdiss_nonlin + mdiss_hmix + mdiss_adv + mdiss_iso + mdiss_skew,
        eke_diss_m=eke_diss + eke_diss_tke,
        iw_diss_m=iw_diss,
        tke_diss_m=tke_diss,
        adv_diss_m=mdiss_adv,
        ke_hd_m=wrhom,
        ke_eke_m=mdiss_h + mdiss_gm,
        # fix: use the corrected value computed above (was -mdiss_skew)
        hd_eke_m=hd_eke_m,
        ke_tke_m=ke_tke_m,
        ke_iw_m=ke_iw_m,
        tke_hd_m=tke_hd_m,
        eke_tke_m=eke_diss_tke,
        eke_iw_m=eke_diss,
        cabb_m=mdiss_nonlin,
        cabb_iso_m=mdiss_hmix + mdiss_iso,
    )
import os
from veros import logger, veros_kernel, KernelOutput
from veros.diagnostics.base import VerosDiagnostic
from veros.core import density
from veros.variables import Variable, allocate
from veros.distributed import global_sum
from veros.core.operators import numpy as npx, update, update_add, at, for_loop
# Metadata for the overturning diagnostic's own (zonally averaged) variables.
# "sigma" is the extra potential-density axis; the remaining fields are
# accumulated transports / streamfunctions on the yu grid.
VARIABLES = {
    "nitts": Variable("nitts", None, write_to_restart=True),
    "sigma": Variable("Sigma axis", ("sigma",), "kg/m^3", "Sigma axis", time_dependent=False, write_to_restart=True),
    "zarea": Variable(
        "zarea",
        ("yu", "zt"),
        write_to_restart=True,
    ),
    "trans": Variable("Meridional transport", ("yu", "sigma"), "m^3/s", "Meridional transport", write_to_restart=True),
    "vsf_iso": Variable("Meridional transport", ("yu", "zw"), "m^3/s", "Meridional transport", write_to_restart=True),
    "vsf_depth": Variable("Meridional transport", ("yu", "zw"), "m^3/s", "Meridional transport", write_to_restart=True),
    # bolus (eddy-driven) transports only exist with GM / skew diffusion
    "bolus_iso": Variable(
        "Meridional transport",
        ("yu", "zw"),
        "m^3/s",
        "Meridional transport",
        write_to_restart=True,
        active=lambda settings: settings.enable_neutral_diffusion and settings.enable_skew_diffusion,
    ),
    "bolus_depth": Variable(
        "Meridional transport",
        ("yu", "zw"),
        "m^3/s",
        "Meridional transport",
        write_to_restart=True,
        active=lambda settings: settings.enable_neutral_diffusion and settings.enable_skew_diffusion,
    ),
}
# all overturning diagnostics except the sample counter are written to output
DEFAULT_OUTPUT_VARS = [var for var in VARIABLES if var != "nitts"]
class Overturning(VerosDiagnostic):
    """Isopycnal overturning diagnostic. Computes and writes vertical streamfunctions
    (zonally averaged).
    """

    name = "overturning"  #:
    output_path = "{identifier}.overturning.nc"  #: File to write to. May contain format strings that are replaced with Veros attributes.
    output_frequency = None  #: Frequency (in seconds) in which output is written.
    sampling_frequency = None  #: Frequency (in seconds) in which variables are accumulated.
    p_ref = 2000.0  #: Reference pressure for isopycnals
    var_meta = VARIABLES

    def __init__(self, state):
        # only include variables that are active for the current settings
        self.output_variables = []
        for var in DEFAULT_OUTPUT_VARS:
            active = self.var_meta[var].active
            if callable(active):
                active = active(state.settings)
            if active:
                self.output_variables.append(var)

    def initialize(self, state):
        """Set up the sigma axis and precompute the area below each z level."""
        vs = state.variables
        settings = state.settings
        # sigma levels
        nlevel = settings.nz * 4
        # bracket the potential density range with extreme (S=35, T) end members
        sige = density.get_potential_rho(state, 35.0, -2.0, press_ref=self.p_ref)
        sigs = density.get_potential_rho(state, 35.0, 30.0, press_ref=self.p_ref)
        dsig = float(sige - sigs) / (nlevel - 1)
        logger.debug(" Sigma ranges for overturning diagnostic:")
        logger.debug(f" Start sigma0 = {sigs:.1f}")
        logger.debug(f" End sigma0 = {sige:.1f}")
        logger.debug(f" Delta sigma0 = {dsig:.1e}")
        if settings.enable_neutral_diffusion and settings.enable_skew_diffusion:
            logger.debug(" Also calculating overturning by eddy-driven velocities")
        self.extra_dimensions = dict(sigma=nlevel)
        self.initialize_variables(state)
        ovt_vs = self.variables
        ovt_vs.sigma = sigs + dsig * npx.arange(nlevel)
        # precalculate area below z levels
        ovt_vs.zarea = update(
            ovt_vs.zarea,
            at[2:-2, :],
            npx.cumsum(
                zonal_sum(
                    npx.sum(
                        vs.dxt[2:-2, npx.newaxis, npx.newaxis]
                        * vs.cosu[npx.newaxis, 2:-2, npx.newaxis]
                        * vs.maskV[2:-2, 2:-2, :],
                        axis=0,
                    )
                )
                * vs.dzt[npx.newaxis, :],
                axis=1,
            ),
        )
        self.initialize_output(state)

    def diagnose(self, state):
        """Accumulate one transport sample into the running sums."""
        ovt_vs = self.variables
        ovt_vs.update(diagnose_kernel(state, ovt_vs, self.p_ref))
        ovt_vs.nitts = ovt_vs.nitts + 1

    def output(self, state):
        """Convert sums to time means, write them, and reset the accumulators."""
        if not os.path.isfile(self.get_output_file_name(state)):
            self.initialize_output(state)
        ovt_vs = self.variables
        # NOTE(review): bolus_iso / bolus_depth are accumulated in diagnose()
        # but are not in mean_variables, so they are neither averaged nor
        # reset here — confirm whether that is intended
        mean_variables = ("trans", "vsf_iso", "vsf_depth")
        if ovt_vs.nitts > 0:
            for var in mean_variables:
                if var not in self.output_variables:
                    continue
                val = getattr(ovt_vs, var)
                setattr(ovt_vs, var, val / ovt_vs.nitts)
        self.write_output(state)
        for var in mean_variables:
            if var not in self.output_variables:
                continue
            val = getattr(ovt_vs, var)
            setattr(ovt_vs, var, val * 0)
        ovt_vs.nitts = 0
@veros_kernel
def _interpolate_depth_coords(coords, arr, interp_coords):
    """Interpolate ``arr``, sampled at depths ``coords``, onto ``interp_coords``.

    ``npx.interp`` requires monotonically increasing x-coordinates, while the
    depth axes here decrease, so both coordinate arrays are sign-flipped.
    The interpolation is vectorized over all leading (latitude) rows.
    """
    vec_interp = npx.vectorize(npx.interp, signature="(n),(m),(m)->(n)")
    return vec_interp(-interp_coords, -coords, arr)
@veros_kernel
def diagnose_kernel(state, ovt_vs, p_ref):
    """Compute one sample of the overturning streamfunctions and add it to the
    accumulators in ``ovt_vs``.

    Computes meridional transport and area below each isopycnal (sigma level),
    the depth-space streamfunction, and — with GM / skew diffusion enabled —
    the corresponding eddy-driven (bolus) quantities, then interpolates the
    isopycnal streamfunction back to depth coordinates.
    """
    vs = state.variables
    settings = state.settings
    nlevel = settings.nz * 4
    # sigma at p_ref
    sig_loc = allocate(state.dimensions, ("xt", "yt", "zt"))
    sig_loc = update(
        sig_loc,
        at[2:-2, 2:-1, :],
        density.get_rho(state, vs.salt[2:-2, 2:-1, :, vs.tau], vs.temp[2:-2, 2:-1, :, vs.tau], p_ref),
    )
    # transports below isopycnals and area below isopycnals
    # (density averaged onto the V-cell face between adjacent T cells)
    sig_loc_face = 0.5 * (sig_loc[2:-2, 2:-2, :] + sig_loc[2:-2, 3:-1, :])
    trans = allocate(state.dimensions, ("yu", nlevel))
    z_sig = allocate(state.dimensions, ("yu", nlevel))
    # cell-face area element for meridional transport
    fac = (
        vs.dxt[2:-2, npx.newaxis, npx.newaxis]
        * vs.cosu[npx.newaxis, 2:-2, npx.newaxis]
        * vs.dzt[npx.newaxis, npx.newaxis, :]
        * vs.maskV[2:-2, 2:-2, :]
    )

    def loop_body(m, values):
        # accumulate transport and area of all cells denser than sigma[m]
        trans, z_sig = values
        mask = sig_loc_face > ovt_vs.sigma[m]
        trans = update(trans, at[2:-2, m], npx.sum(vs.v[2:-2, 2:-2, :, vs.tau] * fac * mask, axis=(0, 2)))
        z_sig = update(z_sig, at[2:-2, m], npx.sum(fac * mask, axis=(0, 2)))
        return (trans, z_sig)

    trans, z_sig = for_loop(0, nlevel, loop_body, init_val=(trans, z_sig))
    trans = zonal_sum(trans)
    z_sig = zonal_sum(z_sig)
    ovt_vs.trans = ovt_vs.trans + trans
    if settings.enable_neutral_diffusion and settings.enable_skew_diffusion:
        # eddy-driven transports below isopycnals
        bolus_trans = allocate(state.dimensions, ("yu", nlevel))

        def loop_body(m, bolus_trans):
            # vertical difference of the GM streamfunction gives the bolus velocity
            mask = sig_loc_face > ovt_vs.sigma[m]
            bolus_trans = update(
                bolus_trans,
                at[2:-2, m],
                npx.sum(
                    npx.sum(
                        (vs.B1_gm[2:-2, 2:-2, 1:] - vs.B1_gm[2:-2, 2:-2, :-1])
                        * vs.dxt[2:-2, npx.newaxis, npx.newaxis]
                        * vs.cosu[npx.newaxis, 2:-2, npx.newaxis]
                        * vs.maskV[2:-2, 2:-2, 1:]
                        * mask[:, :, 1:],
                        axis=2,
                    )
                    + vs.B1_gm[2:-2, 2:-2, 0]
                    * vs.dxt[2:-2, npx.newaxis]
                    * vs.cosu[npx.newaxis, 2:-2]
                    * vs.maskV[2:-2, 2:-2, 0]
                    * mask[:, :, 0],
                    axis=0,
                ),
            )
            return bolus_trans

        bolus_trans = for_loop(0, nlevel, loop_body, init_val=bolus_trans)
        bolus_trans = zonal_sum(bolus_trans)
    # streamfunction on geopotentials
    ovt_vs.vsf_depth = update_add(
        ovt_vs.vsf_depth,
        at[2:-2, :],
        npx.cumsum(
            zonal_sum(
                npx.sum(
                    vs.dxt[2:-2, npx.newaxis, npx.newaxis]
                    * vs.cosu[npx.newaxis, 2:-2, npx.newaxis]
                    * vs.v[2:-2, 2:-2, :, vs.tau]
                    * vs.maskV[2:-2, 2:-2, :],
                    axis=0,
                )
            )
            * vs.dzt[npx.newaxis, :],
            axis=1,
        ),
    )
    if settings.enable_neutral_diffusion and settings.enable_skew_diffusion:
        # streamfunction for eddy driven velocity on geopotentials
        ovt_vs.bolus_depth = update_add(
            ovt_vs.bolus_depth,
            at[2:-2, :],
            zonal_sum(
                npx.sum(
                    vs.dxt[2:-2, npx.newaxis, npx.newaxis]
                    * vs.cosu[npx.newaxis, 2:-2, npx.newaxis]
                    * vs.B1_gm[2:-2, 2:-2, :],
                    axis=0,
                )
            ),
        )
    # interpolate from isopycnals to depth
    ovt_vs.vsf_iso = update_add(
        ovt_vs.vsf_iso, at[2:-2, :], _interpolate_depth_coords(z_sig[2:-2, :], trans[2:-2, :], ovt_vs.zarea[2:-2, :])
    )
    if settings.enable_neutral_diffusion and settings.enable_skew_diffusion:
        ovt_vs.bolus_iso = update_add(
            ovt_vs.bolus_iso,
            at[2:-2, :],
            _interpolate_depth_coords(z_sig[2:-2, :], bolus_trans[2:-2, :], ovt_vs.zarea[2:-2, :]),
        )
    return KernelOutput(
        trans=ovt_vs.trans,
        vsf_depth=ovt_vs.vsf_depth,
        vsf_iso=ovt_vs.vsf_iso,
        bolus_iso=ovt_vs.bolus_iso,
        bolus_depth=ovt_vs.bolus_depth,
    )
def zonal_sum(arr):
    """Sum ``arr`` over the zonal (x) process axis across all MPI ranks."""
    return global_sum(arr, axis=0)
import os
import copy
from veros import time, logger
from veros.diagnostics.base import VerosDiagnostic
# default set of model variables written by the snapshot diagnostic; entries
# that are inactive for the current settings are filtered out in __init__
DEFAULT_OUTPUT_VARS = [
    # grid spacings and depth levels
    "dxt", "dxu", "dyt", "dyu", "zt", "zw", "dzt", "dzw",
    # topography and cell geometry
    "ht", "hu", "hv", "beta", "area_t", "area_u", "area_v",
    # density / equation of state
    "rho", "prho", "int_drhodT", "int_drhodS", "Nsqr", "Hd",
    # tracers and surface tracer fluxes
    "temp", "salt", "forc_temp_surface", "forc_salt_surface",
    # velocities and pressure
    "u", "v", "w", "p_hydro",
    # mixing coefficients and surface forcing
    "kappaM", "kappaH", "surface_taux", "surface_tauy", "forc_rho_surface",
    # streamfunctions and islands
    "psi", "isle", "psin",
    # coordinates
    "xt", "xu", "yt", "yu",
    # source terms
    "temp_source", "salt_source", "u_source", "v_source",
    # turbulence closure variables
    "tke", "forc_tke_surface", "eke", "E_iw", "forc_iw_surface", "forc_iw_bottom",
]
class Snapshot(VerosDiagnostic):
    """Writes snapshots of the current solution. Also reads and writes the main restart
    data required for restarting a Veros simulation.
    """

    output_path = "{identifier}.snapshot.nc"
    """File to write to. May contain format strings that are replaced with Veros attributes."""

    name = "snapshot"  #:
    output_frequency = None  #: Frequency (in seconds) in which output is written.

    def __init__(self, state):
        # restrict output to the default variables that are active for the
        # current settings (e.g. tke only when the TKE closure is enabled)
        self.output_variables = []
        for var in DEFAULT_OUTPUT_VARS:
            active = state.var_meta[var].active
            if callable(active):
                active = active(state.settings)
            if active:
                self.output_variables.append(var)

    def initialize(self, state):
        """Copy variable metadata from the model state and create the output file."""
        vs = state.variables
        self.var_meta = {var: copy.copy(state.var_meta[var]) for var in self.output_variables}
        # snapshot output is separate from the restart mechanism
        for var in self.var_meta.values():
            var.write_to_restart = False
        # snapshots read straight from the live model state
        self.variables = vs
        self.initialize_output(state)

    def diagnose(self, state):
        """Nothing to accumulate; snapshots read the live state at output time."""
        pass

    def output(self, state):
        """Write one snapshot record, creating the output file if needed."""
        vs = state.variables
        time_length, time_unit = time.format_time(vs.time)
        logger.info(f" Writing snapshot at {time_length:.2f} {time_unit}")
        if not os.path.isfile(self.get_output_file_name(state)):
            self.initialize_output(state)
        self.write_output(state)
from veros import logger
from veros.variables import Variable
from veros.core.operators import numpy as npx
from veros.diagnostics.base import VerosDiagnostic
from veros.distributed import global_sum
class TracerMonitor(VerosDiagnostic):
    """Diagnostic monitoring global tracer contents / fluxes.

    Writes output to stdout (no binary output).
    """

    name = "tracer_monitor"
    output_frequency = None

    def __init__(self, state):
        # previous totals / variances, kept to report changes between outputs
        self.var_meta = {
            "tempm1": Variable("tempm1", None, write_to_restart=True),
            "vtemp1": Variable("vtemp1", None, write_to_restart=True),
            "saltm1": Variable("saltm1", None, write_to_restart=True),
            "vsalt1": Variable("vsalt1", None, write_to_restart=True),
        }

    def initialize(self, state):
        """Allocate the scalar accumulator variables."""
        self.initialize_variables(state)

    def diagnose(self, state):
        """Nothing to accumulate; all work happens in :meth:`output`."""
        pass

    def output(self, state):
        """
        Diagnose tracer content
        """
        vs = state.variables
        tracer_vs = self.variables
        # wet T-cell volumes over the physical domain
        cell_volume = vs.area_t[2:-2, 2:-2, npx.newaxis] * vs.dzt[npx.newaxis, npx.newaxis, :] * vs.maskT[2:-2, 2:-2, :]
        volm = global_sum(npx.sum(cell_volume))
        tempm = global_sum(npx.sum(cell_volume * vs.temp[2:-2, 2:-2, :, vs.tau]))
        saltm = global_sum(npx.sum(cell_volume * vs.salt[2:-2, 2:-2, :, vs.tau]))
        vtemp = global_sum(npx.sum(cell_volume * vs.temp[2:-2, 2:-2, :, vs.tau] ** 2))
        vsalt = global_sum(npx.sum(cell_volume * vs.salt[2:-2, 2:-2, :, vs.tau] ** 2))
        logger.diagnostic(
            f" Mean temperature {tempm / volm:.2e} change to last {(tempm - tracer_vs.tempm1) / volm:.2e}"
        )
        logger.diagnostic(
            f" Mean salinity {saltm / volm:.2e} change to last {(saltm - tracer_vs.saltm1) / volm:.2e}"
        )
        logger.diagnostic(
            f" Temperature var. {vtemp / volm:.2e} change to last {(vtemp - tracer_vs.vtemp1) / volm:.2e}"
        )
        logger.diagnostic(
            f" Salinity var. {vsalt / volm:.2e} change to last {(vsalt - tracer_vs.vsalt1) / volm:.2e}"
        )
        # remember current values so the next output can report tendencies
        tracer_vs.tempm1 = tempm
        tracer_vs.vtemp1 = vtemp
        tracer_vs.saltm1 = saltm
        tracer_vs.vsalt1 = vsalt
import functools
from veros import runtime_settings as rs, runtime_state as rst
from veros.routines import CURRENT_CONTEXT
# dimensions that are scattered across processes: (x-axis grids, y-axis grids)
SCATTERED_DIMENSIONS = (("xt", "xu"), ("yt", "yu"))
def dist_context_only(function=None, *, noop_return_arg=None):
    """Decorator (usable with or without arguments) that makes the wrapped
    function a no-op outside a distributed context.

    When running on a single process, or when the current routine context is
    not safe for distributed operations, the wrapper returns ``None`` — or,
    if ``noop_return_arg`` is given, the positional argument with that index
    unchanged (useful for exchange/reduce functions that transform an array).
    """

    def decorator(function):
        @functools.wraps(function)
        def dist_context_only_wrapper(*args, **kwargs):
            if rst.proc_num == 1 or not CURRENT_CONTEXT.is_dist_safe:
                # no-op for sequential execution
                if noop_return_arg is None:
                    return None
                # return input array unchanged
                return args[noop_return_arg]
            return function(*args, **kwargs)

        return dist_context_only_wrapper

    # support usage as a bare decorator (without parentheses)
    if function is not None:
        return decorator(function)
    return decorator
def send(buf, dest, comm, tag=None):
    """Blocking point-to-point send of ``buf`` to rank ``dest`` over ``comm``.

    On the JAX backend this goes through mpi4jax, threading the explicit
    mpi4jax token stored on the current context; otherwise a plain
    ``comm.Send`` of a contiguous buffer is used.
    """
    kwargs = {}
    if tag is not None:
        kwargs.update(tag=tag)
    if rs.backend == "jax":
        from mpi4jax import send

        token = CURRENT_CONTEXT.mpi4jax_token
        new_token = send(buf, dest=dest, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
    else:
        comm.Send(ascontiguousarray(buf), dest=dest, **kwargs)
def recv(buf, source, comm, tag=None):
    """Blocking receive from rank ``source`` over ``comm``.

    ``buf`` provides shape and dtype of the expected message; the received
    data is returned as a new buffer (the input is never mutated).
    """
    kwargs = {}
    if tag is not None:
        kwargs.update(tag=tag)
    if rs.backend == "jax":
        from mpi4jax import recv

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = recv(buf, source=source, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf
    # copy so the caller's input buffer is not mutated by MPI
    buf = buf.copy()
    comm.Recv(buf, source=source, **kwargs)
    return buf
def sendrecv(sendbuf, recvbuf, source, dest, comm, sendtag=None, recvtag=None):
    """Combined send/receive: send ``sendbuf`` to ``dest`` while receiving
    from ``source`` into (a copy of) ``recvbuf``; returns the received data.

    Using the combined MPI operation avoids the deadlocks that two separate
    blocking calls could produce in the ring-style overlap exchange.
    """
    kwargs = {}
    if sendtag is not None:
        kwargs.update(sendtag=sendtag)
    if recvtag is not None:
        kwargs.update(recvtag=recvtag)
    if rs.backend == "jax":
        from mpi4jax import sendrecv

        token = CURRENT_CONTEXT.mpi4jax_token
        recvbuf, new_token = sendrecv(sendbuf, recvbuf, source=source, dest=dest, comm=comm, token=token, **kwargs)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return recvbuf
    # copy so the caller's input buffer is not mutated by MPI
    recvbuf = recvbuf.copy()
    comm.Sendrecv(sendbuf=ascontiguousarray(sendbuf), recvbuf=recvbuf, source=source, dest=dest, **kwargs)
    return recvbuf
def bcast(buf, comm, root=0):
    """Broadcast ``buf`` from rank ``root`` to all ranks in ``comm``."""
    if rs.backend == "jax":
        from mpi4jax import bcast

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = bcast(buf, root=root, comm=comm, token=token)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf
    # generic (pickle-based) broadcast on the non-JAX backends
    return comm.bcast(buf, root=root)
def allreduce(buf, op, comm):
    """Reduce ``buf`` with MPI operation ``op`` across ``comm``, returning the
    result on every rank."""
    if rs.backend == "jax":
        from mpi4jax import allreduce

        token = CURRENT_CONTEXT.mpi4jax_token
        buf, new_token = allreduce(buf, op=op, comm=comm, token=token)
        CURRENT_CONTEXT.mpi4jax_token = new_token
        return buf
    from veros.core.operators import numpy as npx

    recvbuf = npx.empty_like(buf)
    comm.Allreduce(ascontiguousarray(buf), recvbuf, op=op)
    return recvbuf
def ascontiguousarray(arr):
    """Return ``arr`` as a C-contiguous NumPy array, as required by the raw
    MPI buffer interface (only valid on the NumPy backend)."""
    assert rs.backend == "numpy"
    import numpy

    return numpy.ascontiguousarray(arr)
def validate_decomposition(dimensions):
    """Check that the configured process layout is usable for the given domain.

    Raises:
        RuntimeError: if more than one process is requested without mpi4py,
            or if ``num_proc`` does not match the communicator size.
        ValueError: if the domain does not divide evenly among processes.
    """
    nx, ny = dimensions["xt"], dimensions["yt"]
    if rs.mpi_comm is None:
        if rs.num_proc[0] > 1 or rs.num_proc[1] > 1:
            raise RuntimeError("mpi4py is required for distributed execution")
        return
    comm_size = rs.mpi_comm.Get_size()
    proc_num = rs.num_proc[0] * rs.num_proc[1]
    if proc_num != comm_size:
        raise RuntimeError(f"number of processes ({proc_num}) does not match size of communicator ({comm_size})")
    if nx % rs.num_proc[0]:
        raise ValueError("processes do not divide domain evenly in x-direction")
    if ny % rs.num_proc[1]:
        raise ValueError("processes do not divide domain evenly in y-direction")
def get_chunk_size(nx, ny):
    """Return the per-process chunk size (without ghost cells) in x and y."""
    px, py = rs.num_proc
    return (nx // px, ny // py)
def proc_rank_to_index(rank):
    """Convert a flat process rank into (x, y) grid indices (x varies fastest)."""
    iy, ix = divmod(rank, rs.num_proc[0])
    return (ix, iy)
def proc_index_to_rank(ix, iy):
    """Inverse of :func:`proc_rank_to_index`: flatten grid indices to a rank."""
    return iy * rs.num_proc[0] + ix
def get_chunk_slices(nx, ny, dim_grid, proc_idx=None, include_overlap=False):
    """Return ``(global_slice, local_slice)`` mapping one process' chunk of a
    variable with dimensions ``dim_grid`` into the global domain.

    ``proc_idx`` defaults to the calling process. With ``include_overlap``,
    the slices are extended so that only processes at the domain boundary
    include the outermost two-cell padding, while interior edges stop at the
    ghost-cell boundary.
    """
    if not dim_grid:
        return Ellipsis, Ellipsis
    if proc_idx is None:
        proc_idx = proc_rank_to_index(rst.proc_rank)
    px, py = proc_idx
    nxl, nyl = get_chunk_size(nx, ny)
    if include_overlap:
        # only boundary processes own the two-cell domain padding
        sxl = 0 if px == 0 else 2
        sxu = nxl + 4 if (px + 1) == rs.num_proc[0] else nxl + 2
        syl = 0 if py == 0 else 2
        syu = nyl + 4 if (py + 1) == rs.num_proc[1] else nyl + 2
    else:
        sxl = syl = 0
        sxu = nxl
        syu = nyl
    global_slice, local_slice = [], []
    for dim in dim_grid:
        if dim in SCATTERED_DIMENSIONS[0]:
            global_slice.append(slice(sxl + px * nxl, sxu + px * nxl))
            local_slice.append(slice(sxl, sxu))
        elif dim in SCATTERED_DIMENSIONS[1]:
            global_slice.append(slice(syl + py * nyl, syu + py * nyl))
            local_slice.append(slice(syl, syu))
        else:
            # non-decomposed dimensions are taken whole
            global_slice.append(slice(None))
            local_slice.append(slice(None))
    return tuple(global_slice), tuple(local_slice)
def get_process_neighbors(cyclic=False):
    """Return a mapping of direction name -> neighboring process rank (or
    ``None`` at a domain boundary).

    Covers the four direct neighbors and the four corners. With ``cyclic``,
    the x direction wraps around (periodic east-west boundary); the y
    direction never wraps.
    """
    this_x, this_y = proc_rank_to_index(rst.proc_rank)
    if this_x == 0:
        if cyclic:
            west = rs.num_proc[0] - 1
        else:
            west = None
    else:
        west = this_x - 1
    if this_x == rs.num_proc[0] - 1:
        if cyclic:
            east = 0
        else:
            east = None
    else:
        east = this_x + 1
    south = this_y - 1 if this_y != 0 else None
    north = this_y + 1 if this_y != (rs.num_proc[1] - 1) else None
    neighbors = dict(
        # direct neighbors
        west=(west, this_y),
        south=(this_x, south),
        east=(east, this_y),
        north=(this_x, north),
        # corners
        southwest=(west, south),
        southeast=(east, south),
        northeast=(east, north),
        northwest=(west, north),
    )
    # any None component means there is no neighbor in that direction
    global_neighbors = {k: proc_index_to_rank(*i) if None not in i else None for k, i in neighbors.items()}
    return global_neighbors
@dist_context_only(noop_return_arg=0)
def exchange_overlap(arr, var_grid, cyclic):
    """Exchange the two-cell ghost layers of ``arr`` with neighboring
    processes and return the updated array.

    ``var_grid`` gives the names of the leading dimensions; only dimensions
    that are decomposed (see SCATTERED_DIMENSIONS) take part in the exchange.
    Sends and receives are paired in opposite directions so that every
    sendrecv matches up across processes without deadlocking.
    """
    from veros.core.operators import numpy as npx, update, at

    # start west, go clockwise
    send_order = (
        "west",
        "northwest",
        "north",
        "northeast",
        "east",
        "southeast",
        "south",
        "southwest",
    )
    # start east, go clockwise
    recv_order = (
        "east",
        "southeast",
        "south",
        "southwest",
        "west",
        "northwest",
        "north",
        "northeast",
    )
    if len(var_grid) < 2:
        d1, d2 = var_grid[0], None
    else:
        d1, d2 = var_grid[:2]
    if d1 not in SCATTERED_DIMENSIONS[0] and d1 not in SCATTERED_DIMENSIONS[1] and d2 not in SCATTERED_DIMENSIONS[1]:
        # neither x nor y dependent, nothing to do
        return arr
    proc_neighbors = get_process_neighbors(cyclic)
    if d1 in SCATTERED_DIMENSIONS[0] and d2 in SCATTERED_DIMENSIONS[1]:
        # 2D decomposition: exchange edges and corners;
        # "from" slices are the outermost interior cells, "to" slices are the
        # ghost layers they fill on the receiving side
        overlap_slices_from = dict(
            west=(slice(2, 4), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(2, 4), Ellipsis),
            east=(slice(-4, -2), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-4, -2), Ellipsis),
            southwest=(slice(2, 4), slice(2, 4), Ellipsis),
            southeast=(slice(-4, -2), slice(2, 4), Ellipsis),
            northeast=(slice(-4, -2), slice(-4, -2), Ellipsis),
            northwest=(slice(2, 4), slice(-4, -2), Ellipsis),
        )
        overlap_slices_to = dict(
            west=(slice(0, 2), slice(0, None), Ellipsis),
            south=(slice(0, None), slice(0, 2), Ellipsis),
            east=(slice(-2, None), slice(0, None), Ellipsis),
            north=(slice(0, None), slice(-2, None), Ellipsis),
            southwest=(slice(0, 2), slice(0, 2), Ellipsis),
            southeast=(slice(-2, None), slice(0, 2), Ellipsis),
            northeast=(slice(-2, None), slice(-2, None), Ellipsis),
            northwest=(slice(0, 2), slice(-2, None), Ellipsis),
        )
    else:
        # 1D decomposition: only exchange along the single decomposed axis
        if d1 in SCATTERED_DIMENSIONS[0]:
            send_order = ("west", "east")
            recv_order = ("east", "west")
        elif d1 in SCATTERED_DIMENSIONS[1]:
            send_order = ("north", "south")
            recv_order = ("south", "north")
        else:
            raise NotImplementedError()
        overlap_slices_from = dict(
            west=(slice(2, 4), Ellipsis),
            south=(slice(2, 4), Ellipsis),
            east=(slice(-4, -2), Ellipsis),
            north=(slice(-4, -2), Ellipsis),
        )
        overlap_slices_to = dict(
            west=(slice(0, 2), Ellipsis),
            south=(slice(0, 2), Ellipsis),
            east=(slice(-2, None), Ellipsis),
            north=(slice(-2, None), Ellipsis),
        )
    for send_dir, recv_dir in zip(send_order, recv_order):
        send_proc = proc_neighbors[send_dir]
        recv_proc = proc_neighbors[recv_dir]
        if send_proc is None and recv_proc is None:
            continue
        recv_idx = overlap_slices_to[recv_dir]
        recv_arr = npx.empty_like(arr[recv_idx])
        send_idx = overlap_slices_from[send_dir]
        send_arr = arr[send_idx]
        if send_proc is None:
            # at a boundary with only a neighbor to receive from
            recv_arr = recv(recv_arr, recv_proc, rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)
        elif recv_proc is None:
            # at a boundary with only a neighbor to send to
            send(send_arr, send_proc, rs.mpi_comm)
        else:
            recv_arr = sendrecv(send_arr, recv_arr, source=recv_proc, dest=send_proc, comm=rs.mpi_comm)
            arr = update(arr, at[recv_idx], recv_arr)
    return arr
def _memoize(function):
    """Cache the results of *function* keyed by its positional arguments.

    MPI communicators are replaced by their integer handle in the cache key,
    since ``MPI.Comm`` objects are not hashable.
    """
    results = {}

    @functools.wraps(function)
    def wrapper(*args):
        from mpi4py import MPI

        # MPI Comms are not hashable, so we use the underlying handle instead
        key_parts = []
        for arg in args:
            if isinstance(arg, MPI.Comm):
                key_parts.append(MPI._handleof(arg))
            else:
                key_parts.append(arg)
        key = tuple(key_parts)

        try:
            return results[key]
        except KeyError:
            value = results[key] = function(*args)
            return value

    return wrapper
@_memoize
def _mpi_comm_along_axis(comm, procs, rank):
    # Split `comm` into sub-communicators; `procs` acts as the split color
    # (processes with equal color end up in the same sub-communicator) and
    # `rank` as the ordering key. Memoized because Comm.Split is a collective,
    # relatively expensive MPI operation.
    return comm.Split(procs, rank)
@dist_context_only(noop_return_arg=0)
def _reduce(arr, op, axis=None):
    """Reduce `arr` with MPI operation `op` across processes.

    If `axis` is given (0 or 1), the reduction spans only processes that share
    the same index along the *other* axis of the processor grid; otherwise it
    spans all processes. Scalars are supported and returned as scalars.
    """
    from veros.core.operators import numpy as npx
    if axis is None:
        comm = rs.mpi_comm
    else:
        assert axis in (0, 1)
        pi = proc_rank_to_index(rst.proc_rank)
        other_axis = 1 - axis
        # sub-communicator of all ranks with the same index along other_axis
        comm = _mpi_comm_along_axis(rs.mpi_comm, pi[other_axis], rst.proc_rank)
    if npx.isscalar(arr):
        # allreduce operates on arrays; wrap the scalar and unwrap afterwards
        squeeze = True
        arr = npx.array([arr])
    else:
        squeeze = False
    res = allreduce(arr, op=op, comm=comm)
    if squeeze:
        res = res[0]
    return res
@dist_context_only(noop_return_arg=0)
def global_and(arr, axis=None):
    """Logical AND of `arr` across processes (elementwise for arrays)."""
    from mpi4py import MPI
    return _reduce(arr, MPI.LAND, axis=axis)
@dist_context_only(noop_return_arg=0)
def global_or(arr, axis=None):
    """Logical OR of `arr` across processes (elementwise for arrays)."""
    from mpi4py import MPI
    return _reduce(arr, MPI.LOR, axis=axis)
@dist_context_only(noop_return_arg=0)
def global_max(arr, axis=None):
    """Maximum of `arr` across processes (elementwise for arrays)."""
    from mpi4py import MPI
    return _reduce(arr, MPI.MAX, axis=axis)
@dist_context_only(noop_return_arg=0)
def global_min(arr, axis=None):
    """Minimum of `arr` across processes (elementwise for arrays)."""
    from mpi4py import MPI
    return _reduce(arr, MPI.MIN, axis=axis)
@dist_context_only(noop_return_arg=0)
def global_sum(arr, axis=None):
    """Sum of `arr` across processes (elementwise for arrays)."""
    from mpi4py import MPI
    return _reduce(arr, MPI.SUM, axis=axis)
@dist_context_only(noop_return_arg=2)
def _gather_1d(nx, ny, arr, dim):
    """Gather an array distributed along a single horizontal dimension.

    `dim` selects the scattered axis (0 = x, 1 = y). Only processes with
    index 0 along the other processor-grid axis participate; rank 0 assembles
    and returns the global array (shape includes 4 ghost cells), all other
    participating ranks return their local `arr` unchanged.
    """
    from veros.core.operators import numpy as npx, update, at
    assert dim in (0, 1)
    otherdim = 1 - dim
    pi = proc_rank_to_index(rst.proc_rank)
    # processes off the first row/column of the processor grid hold redundant
    # data for a 1D variable and do not take part in the gather
    if pi[otherdim] != 0:
        return arr
    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]
    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            pi = proc_rank_to_index(proc)
            if pi[otherdim] != 0:
                continue
            idx_g, idx_l = get_chunk_slices(nx, ny, dim_grid, include_overlap=True, proc_idx=pi)
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=20, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))
        # global extent along the gathered axis, including 2 ghost cells per side
        out_shape = ((nx + 4, ny + 4)[dim],) + arr.shape[1:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)
        for idx, val in buffer_list:
            out = update(out, at[idx], val)
        return out
    else:
        send(sendbuf, dest=0, tag=20, comm=rs.mpi_comm)
        return arr
@dist_context_only(noop_return_arg=2)
def _gather_xy(nx, ny, arr):
    """Gather an array distributed along both horizontal dimensions.

    Every process sends its local chunk to rank 0; rank 0 assembles and
    returns the global array (with ghost cells), all other ranks return
    their local `arr` unchanged.
    """
    from veros.core.operators import numpy as npx, update, at
    nxi, nyi = get_chunk_size(nx, ny)
    assert arr.shape[:2] == (nxi + 4, nyi + 4), arr.shape
    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    gidx, idx = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    sendbuf = arr[idx]
    if rst.proc_rank == 0:
        buffer_list = []
        for proc in range(1, rst.proc_num):
            idx_g, idx_l = get_chunk_slices(nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc))
            recvbuf = npx.empty_like(arr[idx_l])
            recvbuf = recv(recvbuf, source=proc, tag=30, comm=rs.mpi_comm)
            buffer_list.append((idx_g, recvbuf))
        out_shape = (nx + 4, ny + 4) + arr.shape[2:]
        out = npx.empty(out_shape, dtype=arr.dtype)
        out = update(out, at[gidx], sendbuf)
        for idx, val in buffer_list:
            out = update(out, at[idx], val)
        return out
    # non-root ranks only send
    send(sendbuf, dest=0, tag=30, comm=rs.mpi_comm)
    return arr
@dist_context_only(noop_return_arg=0)
def gather(arr, dimensions, var_grid):
    """Collect a distributed variable onto the root process.

    Dispatches to the 1D or 2D gather implementation based on which of the
    variable's first two dimensions are scattered across processes; variables
    that do not depend on x or y are returned unchanged.
    """
    nx, ny = dimensions["xt"], dimensions["yt"]

    if var_grid is None:
        return arr

    if len(var_grid) < 2:
        dim1, dim2 = var_grid[0], None
    else:
        dim1, dim2 = var_grid[:2]

    x_first = dim1 in SCATTERED_DIMENSIONS[0]
    y_first = dim1 in SCATTERED_DIMENSIONS[1]
    y_second = dim2 in SCATTERED_DIMENSIONS[1]

    if not (x_first or y_first or y_second):
        # neither x nor y dependent, nothing to do
        return arr

    if x_first and not y_second:
        # only x dependent
        return _gather_1d(nx, ny, arr, 0)

    if y_first:
        # only y dependent
        return _gather_1d(nx, ny, arr, 1)

    if x_first and y_second:
        # x and y dependent
        return _gather_xy(nx, ny, arr)

    raise NotImplementedError()
@dist_context_only(noop_return_arg=0)
def _scatter_constant(arr):
    # x/y-independent data is identical on all processes: just broadcast
    # the root's copy to everyone
    return bcast(arr, rs.mpi_comm, root=0)
@dist_context_only(noop_return_arg=2)
def _scatter_1d(nx, ny, arr, dim):
    """Distribute a global 1D-scattered array from rank 0 to all processes.

    `dim` selects the scattered axis (0 = x, 1 = y). On rank 0, `arr` is the
    global array; on other ranks it is a correctly-shaped local buffer that
    gets overwritten. Returns the local chunk with overlap cells exchanged.
    """
    from veros.core.operators import numpy as npx, update, at
    assert dim in (0, 1)
    out_nx = get_chunk_size(nx, ny)[dim]
    dim_grid = ["xt" if dim == 0 else "yt"] + [None] * (arr.ndim - 1)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]
        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=40, comm=rs.mpi_comm)
        # arr changes shape in main process
        arr = npx.zeros((out_nx + 4,) + arr.shape[1:], dtype=arr.dtype)
    else:
        recvbuf = recv(arr[local_slice], source=0, tag=40, comm=rs.mpi_comm)
    arr = update(arr, at[local_slice], recvbuf)
    # fill ghost cells from neighboring chunks
    arr = exchange_overlap(arr, ["xt" if dim == 0 else "yt"], cyclic=False)
    return arr
@dist_context_only(noop_return_arg=2)
def _scatter_xy(nx, ny, arr):
    """Distribute a global 2D-scattered array from rank 0 to all processes.

    On rank 0, `arr` is the global array; on other ranks it is a
    correctly-shaped local buffer that gets overwritten. Returns the local
    chunk with overlap cells exchanged.
    """
    from veros.core.operators import numpy as npx, update, at
    nxi, nyi = get_chunk_size(nx, ny)
    dim_grid = ["xt", "yt"] + [None] * (arr.ndim - 2)
    _, local_slice = get_chunk_slices(nx, ny, dim_grid, include_overlap=True)
    if rst.proc_rank == 0:
        recvbuf = arr[local_slice]
        for proc in range(1, rst.proc_num):
            global_slice, _ = get_chunk_slices(
                nx, ny, dim_grid, include_overlap=True, proc_idx=proc_rank_to_index(proc)
            )
            sendbuf = arr[global_slice]
            send(sendbuf, dest=proc, tag=50, comm=rs.mpi_comm)
        # arr changes shape in main process
        arr = npx.empty((nxi + 4, nyi + 4) + arr.shape[2:], dtype=arr.dtype)
    else:
        recvbuf = npx.empty_like(arr[local_slice])
        recvbuf = recv(recvbuf, source=0, tag=50, comm=rs.mpi_comm)
    arr = update(arr, at[local_slice], recvbuf)
    # fill ghost cells from neighboring chunks
    arr = exchange_overlap(arr, ["xt", "yt"], cyclic=False)
    return arr
@dist_context_only(noop_return_arg=0)
def scatter(arr, dimensions, var_grid):
    """Distribute a global variable from the root process to all processes.

    Mirrors :func:`gather`: dispatches to the constant, 1D, or 2D scatter
    implementation based on which of the variable's first two dimensions
    are scattered across processes.
    """
    from veros.core.operators import numpy as npx

    if var_grid is None:
        return _scatter_constant(arr)

    nx, ny = dimensions["xt"], dimensions["yt"]

    if len(var_grid) < 2:
        dim1, dim2 = var_grid[0], None
    else:
        dim1, dim2 = var_grid[:2]

    arr = npx.asarray(arr)

    x_first = dim1 in SCATTERED_DIMENSIONS[0]
    y_first = dim1 in SCATTERED_DIMENSIONS[1]
    y_second = dim2 in SCATTERED_DIMENSIONS[1]

    if not (x_first or y_first or y_second):
        # neither x nor y dependent
        return _scatter_constant(arr)

    if x_first and not y_second:
        # only x dependent
        return _scatter_1d(nx, ny, arr, 0)

    if y_first:
        # only y dependent
        return _scatter_1d(nx, ny, arr, 1)

    if x_first and y_second:
        # x and y dependent
        return _scatter_xy(nx, ny, arr)

    raise NotImplementedError("unreachable")
@dist_context_only
def barrier():
    """Block until all MPI processes have reached this point."""
    rs.mpi_comm.barrier()
@dist_context_only
def abort():
    """Forcibly terminate all processes of the MPI run."""
    rs.mpi_comm.Abort()
import threading
import contextlib
from veros import logger, runtime_settings, runtime_state
@contextlib.contextmanager
def threaded_io(filepath, mode):
    """
    If using IO threads, start a new thread to write the HDF5 data to disk.

    Yields an open ``h5py.File``. On exit the file is closed either in a
    background thread (when ``use_io_threads`` is set) or synchronously.
    """
    import h5py
    if runtime_settings.use_io_threads:
        # block until any previous writer thread for this file has finished,
        # then mark the file as busy
        _wait_for_disk(filepath)
        _io_locks[filepath].clear()
    kwargs = {}
    if runtime_state.proc_num > 1:
        # parallel runs go through HDF5's MPI-IO driver
        kwargs.update(driver="mpio", comm=runtime_settings.mpi_comm)
    h5file = h5py.File(filepath, mode, **kwargs)
    try:
        yield h5file
    finally:
        if runtime_settings.use_io_threads:
            # close (and flush) the file without blocking the caller
            threading.Thread(target=_write_to_disk, args=(h5file, filepath)).start()
        else:
            _write_to_disk(h5file, filepath)
_io_locks = {}
def _add_to_locks(file_id):
    """
    If there is no lock for file_id, create one
    """
    if file_id in _io_locks:
        return
    lock = threading.Event()
    lock.set()
    _io_locks[file_id] = lock
def _wait_for_disk(file_id):
    """
    Wait for the lock of file_id to be released
    """
    logger.debug(f"Waiting for lock {file_id} to be released")
    _add_to_locks(file_id)
    released_in_time = _io_locks[file_id].wait(runtime_settings.io_timeout)
    if not released_in_time:
        raise RuntimeError("Timeout while waiting for disk IO to finish")
def _write_to_disk(h5file, file_id):
    """
    Sync HDF5 data to disk, close file handle, and release lock.
    May run in a separate thread.
    """
    try:
        h5file.close()
    finally:
        threaded = runtime_settings.use_io_threads and file_id is not None
        if threaded:
            _io_locks[file_id].set()
import json
import datetime
import threading
import contextlib
import numpy as np
from veros import (
logger,
variables,
distributed,
runtime_state,
runtime_settings as rs,
__version__ as veros_version,
)
"""
netCDF output is designed to follow the COARDS guidelines from
http://ferret.pmel.noaa.gov/Ferret/documentation/coards-netcdf-conventions
"""
def _get_setup_code(pyfile):
try:
with open(pyfile, "r") as f:
return f.read()
except FileNotFoundError:
return "UNKNOWN"
def initialize_file(state, ncfile, extra_dimensions=None, create_time_dimension=True):
    """
    Define standard grid in netcdf file

    Writes global metadata attributes, creates all active model dimensions
    (plus `extra_dimensions`) with their coordinate variables, and optionally
    an unlimited "Time" dimension.
    """
    import h5netcdf
    if not isinstance(ncfile, h5netcdf.File):
        raise TypeError("Argument needs to be a netCDF4 Dataset")
    if rs.setup_file is None:
        setup_file = "UNKNOWN"
        setup_code = "UNKNOWN"
    else:
        setup_file = rs.setup_file
        setup_code = _get_setup_code(rs.setup_file)
    ncfile.attrs.update(
        date_created=datetime.datetime.today().isoformat(),
        veros_version=veros_version,
        setup_identifier=state.settings.identifier,
        setup_description=state.settings.description,
        setup_settings=json.dumps(state.settings.todict()),
        setup_file=setup_file,
        setup_code=setup_code,
    )
    dimensions = dict(state.dimensions)
    if extra_dimensions is not None:
        dimensions.update(extra_dimensions)
    for dim in dimensions:
        # time steps are peeled off explicitly
        if dim in variables.TIMESTEPS:
            continue
        if dim in state.var_meta:
            var = state.var_meta[dim]
            # skip inactive dimensions
            if not var.active:
                continue
            var_data = getattr(state.variables, dim)
        else:
            # create dummy variable for dimensions without data
            var = variables.Variable(dim, (dim,), time_dependent=False)
            var_data = np.arange(dimensions[dim])
        # global dimension extent, without ghost cells
        dimsize = variables.get_shape(dimensions, var.dims[::-1], include_ghosts=False, local=False)[0]
        ncfile.dimensions[dim] = dimsize
        initialize_variable(state, dim, var, ncfile)
        write_variable(state, dim, var, var_data, ncfile)
    if create_time_dimension:
        # None creates an unlimited (resizable) dimension
        ncfile.dimensions["Time"] = None
        nc_dim_var_time = ncfile.create_variable("Time", ("Time",), float)
        nc_dim_var_time.attrs.update(
            long_name="Time",
            units="days",
            time_origin="01-JAN-1900 00:00:00",
        )
def initialize_variable(state, key, var, ncfile):
    """Create (but do not fill) the netCDF variable `key` described by `var`.

    Dimensions are transposed relative to the model layout, and a missing
    value / chunking / optional gzip compression are configured. Does nothing
    if the variable already exists in the file.
    """
    if var.dims is None:
        dims = ()
    else:
        # keep only dimensions that were actually created in this file
        dims = tuple(d for d in var.dims if d in ncfile.dimensions)
    if var.time_dependent and "Time" in ncfile.dimensions:
        dims += ("Time",)
    if key in ncfile.variables:
        logger.warning(f"Variable {key} already initialized")
        return
    kwargs = {}
    # gzip compression is incompatible with parallel (mpio) writes
    if rs.hdf5_gzip_compression and runtime_state.proc_num == 1:
        kwargs.update(compression="gzip", compression_opts=1)
    # chunk by the local (per-process) extent of each dimension
    chunksize = [
        variables.get_shape(state.dimensions, (d,), local=True, include_ghosts=False)[0] if d in state.dimensions else 1
        for d in dims
    ]
    dtype = var.dtype
    if dtype is None:
        dtype = rs.float_type
    elif dtype == "bool":
        # booleans are stored as uint8, since netCDF has no boolean type
        dtype = "uint8"
    fillvalue = variables.get_fill_value(dtype)
    # transpose all dimensions in netCDF output (convention in most ocean models)
    v = ncfile.create_variable(key, dims[::-1], dtype, fillvalue=fillvalue, chunks=tuple(chunksize[::-1]), **kwargs)
    v.missing_value = fillvalue
    v.attrs.update(long_name=var.name, units=var.units, **var.extra_attributes)
def advance_time(time_value, ncfile):
    """Append `time_value` as a new entry along the unlimited "Time" dimension."""
    next_step = len(ncfile.variables["Time"])
    ncfile.resize_dimension("Time", next_step + 1)
    ncfile.variables["Time"][next_step] = time_value
def add_dimension(dim, dim_size, ncfile):
    """Register dimension `dim` with `dim_size` entries on the given netCDF file."""
    size = int(dim_size)
    ncfile.dimensions[dim] = size
def write_variable(state, key, var, var_data, ncfile, time_step=-1):
    """Write the (local chunk of) `var_data` to an existing netCDF variable.

    Applies the variable's scale factor and mask (masked cells become the
    fill value), drops ghost cells, picks the current time level, transposes
    to netCDF convention, and writes into this process' slice of the file.
    """
    var_data = var_data * var.scale
    gridmask = var.get_mask(state.settings, state.variables)
    if gridmask is not None:
        # broadcast the (lower-rank) mask over trailing dimensions of the data
        newaxes = (slice(None),) * gridmask.ndim + (np.newaxis,) * (var_data.ndim - gridmask.ndim)
        var_data = np.where(gridmask.astype("bool")[newaxes], var_data, variables.get_fill_value(var_data.dtype))
    if var.dims:
        # select the current time level for variables carrying a timestep axis
        tmask = tuple(state.variables.tau if dim in variables.TIMESTEPS else slice(None) for dim in var.dims)
        var_data = variables.remove_ghosts(var_data, var.dims)[tmask].T
    var_obj = ncfile.variables[key]
    nx, ny = state.dimensions["xt"], state.dimensions["yt"]
    chunk, _ = distributed.get_chunk_slices(nx, ny, var_obj.dimensions)
    if "Time" in var_obj.dimensions:
        assert var_obj.dimensions[0] == "Time"
        chunk = (time_step,) + chunk[1:]
    var_obj[chunk] = var_data
@contextlib.contextmanager
def threaded_io(filepath, mode):
    """
    If using IO threads, start a new thread to write the netCDF data to disk.
    """
    import h5py
    import h5netcdf

    if rs.use_io_threads:
        # wait until no other thread is writing this file, then claim it
        _wait_for_disk(filepath)
        _io_locks[filepath].clear()

    file_kwargs = {}
    h5py_major = int(h5py.__version__.split(".")[0])
    if h5py_major >= 3:
        file_kwargs["decode_vlen_strings"] = True
    if runtime_state.proc_num > 1:
        file_kwargs["driver"] = "mpio"
        file_kwargs["comm"] = rs.mpi_comm

    nc_dataset = h5netcdf.File(filepath, mode, **file_kwargs)
    try:
        yield nc_dataset
    finally:
        if rs.use_io_threads:
            writer = threading.Thread(target=_write_to_disk, args=(nc_dataset, filepath))
            writer.start()
        else:
            _write_to_disk(nc_dataset, filepath)
_io_locks = {}
def _add_to_locks(file_id):
    """
    If there is no lock for file_id, create one
    """
    if file_id in _io_locks:
        return
    lock = threading.Event()
    lock.set()
    _io_locks[file_id] = lock
def _wait_for_disk(file_id):
    """
    Wait for the lock of file_id to be released
    """
    logger.debug(f"Waiting for lock {file_id} to be released")
    _add_to_locks(file_id)
    released_in_time = _io_locks[file_id].wait(rs.io_timeout)
    if not released_in_time:
        raise RuntimeError("Timeout while waiting for disk IO to finish")
def _write_to_disk(ncfile, file_id):
    """
    Sync netCDF data to disk, close file handle, and release lock.
    May run in a separate thread.
    """
    try:
        ncfile.close()
    finally:
        threaded = rs.use_io_threads and file_id is not None
        if threaded:
            _io_locks[file_id].set()
import sys
import warnings
LOGLEVELS = ("trace", "debug", "info", "warning", "error")
def _inject_proc_rank(record):
    """Loguru patcher: attach the current MPI process rank to every record."""
    from veros import runtime_state

    record["extra"]["proc_rank"] = runtime_state.proc_rank
def setup_logging(loglevel="info", stream_sink=sys.stdout, log_all_processes=False):
    """Configure the loguru logger for Veros and return it.

    Registers custom levels and colors (once), injects the process rank into
    every record, installs a `diagnostic` log method, and routes Python
    warnings through the logger. With `log_all_processes=False`, only
    records from rank 0 are emitted.
    """
    from loguru import logger
    handler_conf = dict(
        sink=stream_sink,
        level=loglevel.upper(),
        colorize=sys.stdout.isatty(),
    )
    # only register custom levels on the first call
    if not hasattr(logger, "diagnostic"):
        logger.level("DIAGNOSTIC", no=45)
        logger.level("TRACE", color="<dim>")
        logger.level("DEBUG", color="<dim><cyan>")
        logger.level("INFO", color="")
        logger.level("SUCCESS", color="<dim><green>")
        logger.level("WARNING", color="<yellow>")
        logger.level("ERROR", color="<bold><red>")
        logger.level("DIAGNOSTIC", color="<bold><yellow>")
        logger.level("CRITICAL", color="<bold><red><WHITE>")
    logger = logger.patch(_inject_proc_rank)
    if log_all_processes:
        handler_conf.update(format="{extra[proc_rank]} | <level>{message}</level>")
    else:
        # drop records from all processes except rank 0
        handler_conf.update(format="<level>{message}</level>", filter=lambda record: record["extra"]["proc_rank"] == 0)
    def diagnostic(_, message, *args, **kwargs):
        logger.opt(depth=1).log("DIAGNOSTIC", message, *args, **kwargs)
    logger.__class__.diagnostic = diagnostic
    def showwarning(message, cls, source, lineno, *args):
        # route warnings.warn(...) calls through the logger
        logger.warning(
            "{warning}: {message} ({source}:{lineno})",
            message=message,
            warning=cls.__name__,
            source=source,
            lineno=lineno,
        )
    warnings.showwarning = showwarning
    logger.configure(handlers=[handler_conf])
    logger.enable("veros")
    return logger
from collections import namedtuple
from veros.variables import Variable
from veros.settings import Setting
# Container describing everything a plugin registers with the Veros core:
# its entrypoints plus the settings, variables, dimensions, and diagnostics
# it contributes (validated and populated by load_plugin below).
VerosPlugin = namedtuple(
    "VerosPlugin",
    [
        "name",
        "module",
        "setup_entrypoint",
        "run_entrypoint",
        "settings",
        "variables",
        "dimensions",
        "diagnostics",
    ],
)
def load_plugin(module):
    """Validate a plugin module's `__VEROS_INTERFACE__` and wrap it in a VerosPlugin.

    Raises RuntimeError if the module lacks the interface dict or valid
    entrypoints, and TypeError if any declared setting, variable, dimension,
    or diagnostic has the wrong type.
    """
    from veros.diagnostics.base import VerosDiagnostic
    modname = module.__name__
    if not hasattr(module, "__VEROS_INTERFACE__"):
        raise RuntimeError(f"module {modname} is not a valid Veros plugin")
    interface = module.__VEROS_INTERFACE__
    setup_entrypoint = interface.get("setup_entrypoint")
    if not callable(setup_entrypoint):
        raise RuntimeError(f"module {modname} is missing a valid setup entrypoint")
    run_entrypoint = interface.get("run_entrypoint")
    if not callable(run_entrypoint):
        raise RuntimeError(f"module {modname} is missing a valid run entrypoint")
    name = interface.get("name", modname)
    settings = interface.get("settings", {})
    for setting, val in settings.items():
        if not isinstance(val, Setting):
            raise TypeError(f"got unexpected type {type(val)} for setting {setting}")
    variables = interface.get("variables", {})
    for variable, val in variables.items():
        if not isinstance(val, Variable):
            raise TypeError(f"got unexpected type {type(val)} for variable {variable}")
    dimensions = interface.get("dimensions", {})
    for dim, val in dimensions.items():
        # a dimension is either a fixed size (int) or the name of another dimension
        if not isinstance(val, (str, int)):
            raise TypeError(f"got unexpected type {type(val)} for dimension {dim}")
    diagnostics = interface.get("diagnostics", [])
    for diagnostic in diagnostics:
        if not issubclass(diagnostic, VerosDiagnostic):
            raise TypeError(f"got unexpected type {type(diagnostic)} for diagnostic {diagnostic}")
    return VerosPlugin(
        name=name,
        module=module,
        setup_entrypoint=setup_entrypoint,
        run_entrypoint=run_entrypoint,
        settings=settings,
        variables=variables,
        dimensions=dimensions,
        diagnostics=diagnostics,
    )
import sys
import functools
from time import perf_counter
# tqdm is optional; fall back to the logging-based progress report if missing
try:
    import tqdm
except ImportError:
    has_tqdm = False
else:
    has_tqdm = True
from veros import logger, time, logs, runtime_settings as rs, runtime_state as rst
# Progress report template shared by both progress bar implementations below
BAR_FORMAT = (
    " Current iteration: {iteration:<5} ({time:.2f}/{total:.2f}{unit} | {percentage:>4.1f}% | "
    "{rate:.2f}{rate_unit} | {eta:.1f}{eta_unit} left)"
)
class LoggingProgressBar:
    """A simple progress report to logger.info
    Serves as a fallback where TQDM is not available or not feasible (writing to a file,
    in multiprocessing contexts).
    """
    def __init__(self, total, start_time=0, start_iteration=0, time_unit="seconds"):
        # all internal times are kept in seconds; _time_unit is only for display
        self._start_time = start_time
        self._start_iteration = start_iteration
        self._total = total
        _, self._time_unit = time.format_time(total)
    def __enter__(self):
        # wall-clock reference for rate estimation
        self._start = perf_counter()
        self._iteration = self._start_iteration
        self._time = self._start_time
        self.flush()
        return self
    def __exit__(self, *args, **kwargs):
        pass
    def advance_time(self, amount, *args, **kwargs):
        """Advance model time by `amount` seconds and emit a progress line."""
        self._iteration += 1
        self._time += amount
        self.flush()
    def flush(self):
        """Log one progress line (rate, ETA, percentage) via logger.info."""
        report_time = time.convert_time(self._time, "seconds", self._time_unit)
        total_time = time.convert_time(self._total, "seconds", self._time_unit)
        if self._time > self._start_time:
            # wall-clock seconds per model second
            rate_in_seconds = (perf_counter() - self._start) / (self._time - self._start_time)
        else:
            rate_in_seconds = 0
        rate_in_seconds_per_year = rate_in_seconds / time.convert_time(1, "seconds", "years")
        rate, rate_unit = time.format_time(rate_in_seconds_per_year)
        eta, eta_unit = time.format_time((self._total - self._time) * rate_in_seconds)
        if self._start_time < self._total:
            percentage = 100 * (self._time - self._start_time) / (self._total - self._start_time)
        else:
            percentage = 100
        logger.info(
            BAR_FORMAT,
            time=report_time,
            total=total_time,
            unit=self._time_unit[0],
            percentage=percentage,
            iteration=self._iteration,
            rate=rate,
            rate_unit=f"{rate_unit[0]}/(model year)",
            eta=eta,
            eta_unit=eta_unit[0],
        )
class FancyProgressBar:
    """A fancy progress bar based on TQDM that stays at the bottom of the terminal."""
    def __init__(self, total, start_time=0, start_iteration=0, time_unit="seconds"):
        self._time = self._start_time = start_time
        self._iteration = self._start_iteration = start_iteration
        self._total = total
        total_runlen, time_unit = time.format_time(total)
        self._time_unit = time_unit
        # NOTE: _VerosTQDM is defined inside __init__ so its methods can close
        # over `self` (the FancyProgressBar) while `other` is the tqdm instance
        class _VerosTQDM(tqdm.tqdm):
            """Stripped down version of tqdm.tqdm
            We only need TQDM to handle dynamic updates to the progress indicator.
            """
            def __init__(self, *args, **kwargs):
                kwargs.update(leave=True)
                super().__init__(*args, **kwargs)
            @property
            def format_dict(other):
                # build the substitution dict consumed by format_meter / BAR_FORMAT
                report_time = time.convert_time(self._time, "seconds", self._time_unit)
                total_time = time.convert_time(self._total, "seconds", self._time_unit)
                if self._start_time < self._total:
                    percentage = 100 * (self._time - self._start_time) / (self._total - self._start_time)
                else:
                    percentage = 100
                d = super().format_dict
                if d["elapsed"] > 0:
                    if self._time > self._start_time:
                        # wall-clock seconds per model second
                        rate_in_seconds = d["elapsed"] / (self._time - self._start_time)
                    else:
                        rate_in_seconds = 0
                    rate_in_seconds_per_year = rate_in_seconds / time.convert_time(1, "seconds", "years")
                    rate, rate_unit = time.format_time(rate_in_seconds_per_year)
                    eta, eta_unit = time.format_time((self._total - self._time) * rate_in_seconds)
                else:
                    rate, rate_unit = 0, "s"
                    eta, eta_unit = 0, "s"
                d.update(
                    iteration=self._iteration,
                    time=report_time,
                    total=total_time,
                    unit=self._time_unit[0],
                    percentage=percentage,
                    rate=rate,
                    rate_unit=f"{rate_unit[0]}/(model year)",
                    eta=eta,
                    eta_unit=eta_unit[0],
                )
                return d
            def format_meter(other, *args, bar_format, **kwargs):
                return bar_format.format(**kwargs)
        self._pbar = _VerosTQDM(file=sys.stdout, bar_format=BAR_FORMAT)
    def __enter__(self, *args, **kwargs):
        self._iteration = self._start_iteration
        self._time = self._start_time
        # reroute log output through the bar so log lines don't clobber it
        logs.setup_logging(
            loglevel=rs.loglevel, stream_sink=functools.partial(self._pbar.write, file=sys.stdout, end="")
        )
        self._pbar.__enter__(*args, **kwargs)
        return self
    def __exit__(self, *args, **kwargs):
        # restore normal logging before tearing down the bar
        logs.setup_logging(loglevel=rs.loglevel)
        self._pbar.__exit__(*args, **kwargs)
    def advance_time(self, amount):
        """Advance model time by `amount` seconds and refresh the bar."""
        self._iteration += 1
        self._time += amount
        self.flush()
    def flush(self):
        self._pbar.refresh()
def get_progress_bar(state, use_tqdm=None):
    """Construct a progress reporter for the current run.

    When `use_tqdm` is None, the TQDM-based bar is auto-selected only for
    interactive, single-process runs with tqdm installed; otherwise the
    logging-based fallback is used.
    """
    if use_tqdm is None:
        use_tqdm = sys.stdout.isatty() and rst.proc_num == 1 and has_tqdm

    if use_tqdm and not has_tqdm:
        raise RuntimeError("tqdm failed to import. Try `pip install tqdm` or set use_tqdm=False.")

    start_time = float(state.variables.time)
    bar_cls = FancyProgressBar if use_tqdm else LoggingProgressBar
    return bar_cls(
        total=state.settings.runlen + start_time,
        start_time=start_time,
        start_iteration=int(state.variables.itt),
    )
import os
from contextlib import contextmanager
from collections import defaultdict
import importlib.util
from veros import logger, runtime_settings, runtime_state, timer
from veros.state import get_default_state, resize_dimension
from veros.variables import get_shape
# all variables that are re-named or unique to Veros
# (a value of None means the variable has no PyOM2 counterpart and is skipped
# when transferring state between the two models)
VEROS_TO_PYOM_VAR = dict(
    # do not exist in pyom
    time=None,
    prho=None,
    land_map=None,
    isle=None,
    isle_boundary_mask=None,
    line_dir_south_mask=None,
    line_dir_east_mask=None,
    line_dir_north_mask=None,
    line_dir_west_mask=None,
    ssh=None,
)
# all setting that are re-named or unique to Veros
# (None again marks settings without a PyOM2 counterpart)
VEROS_TO_PYOM_SETTING = dict(
    # do not exist in pyom
    identifier=None,
    description=None,
    enable_noslip_lateral=None,
    restart_input_filename=None,
    restart_output_filename=None,
    restart_frequency=None,
    kappaH_min=None,
    enable_kappaH_profile=None,
    enable_Prandtl_tke=None,
    Prandtl_tke0=None,
    biharmonic_friction_cosPower=None,
)
# these are read-only
CONSTANTS = ("pi", "radius", "degtom", "mtodeg", "omega", "rho_0", "grav")
# variables populated by streamfunction_init rather than copied directly
INIT_STREAM_VARS = ("psin", "dpsin", "line_psin")
def _load_fortran_module(module, path):
spec = importlib.util.spec_from_file_location(module, path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
def load_pyom(pyom_lib):
    """Load a compiled PyOM2 library, preferring its MPI-enabled build."""
    try:
        pyom_obj = _load_fortran_module("pyOM_code_MPI", pyom_lib)
        has_mpi = True
    except ImportError:
        # fall back to the serial build
        pyom_obj = _load_fortran_module("pyOM_code", pyom_lib)
        has_mpi = False

    parallel_run = runtime_state.proc_num > 1
    if parallel_run and not has_mpi:
        raise RuntimeError("Given PyOM2 library was not built with MPI support")

    return pyom_obj
@contextmanager
def suppress_stdout(stdout_fd=1):
    """Temporarily redirect the OS-level file descriptor `stdout_fd` to /dev/null.

    Used around PyOM library calls below, which write to stdout at the file
    descriptor level and thus bypass Python-level redirection.
    """
    saved_fd = os.dup(stdout_fd)
    devnull = open(os.devnull, "wb")
    try:
        os.dup2(devnull.fileno(), stdout_fd)
    finally:
        devnull.close()
    try:
        yield
    finally:
        # restore the original descriptor, then release the saved duplicate
        with os.fdopen(saved_fd, "wb") as saved:
            os.dup2(saved.fileno(), stdout_fd)
def pyom_from_state(state, pyom_obj, ignore_attrs=None, init_streamfunction=None):
    """Force-updates internal PyOM library state to match given Veros state.

    Copies settings and variables into the Fortran modules (translating names
    via VEROS_TO_PYOM_*), allocates the library's arrays, optionally runs its
    streamfunction initialization, and mirrors diagnostic output intervals.

    Bug fix: diagnostic settings were previously read via ``getattr(diag, ...)``
    where ``diag`` is the diagnostic's *name string*, raising AttributeError
    for any configured diagnostic; the attribute must be read off the
    diagnostic object ``state.diagnostics[diag]``.
    """
    if ignore_attrs is None:
        ignore_attrs = []
    pyom_modules = (
        pyom_obj.main_module,
        pyom_obj.isoneutral_module,
        pyom_obj.idemix_module,
        pyom_obj.tke_module,
        pyom_obj.eke_module,
    )

    def set_fortran_attr(attr, val):
        # fortran interface is all lower-case
        attr = attr.lower()
        for module in pyom_modules:
            if hasattr(module, attr):
                setattr(module, attr, val)
                break
        else:
            raise RuntimeError(f"Could not set attribute {attr} on Fortran library")

    # settings
    for setting, val in state.settings.items():
        setting = VEROS_TO_PYOM_SETTING.get(setting, setting)
        if setting is None or setting in ignore_attrs or setting in CONSTANTS:
            continue
        set_fortran_attr(setting, val)
    _override_settings(pyom_obj)
    # allocate variables
    if runtime_state.proc_num > 1:
        pyom_obj.my_mpi_init(runtime_settings.mpi_comm.py2f())
    else:
        pyom_obj.my_mpi_init(0)
    pyom_obj.pe_decomposition()
    pyom_obj.allocate_main_module()
    pyom_obj.allocate_isoneutral_module()
    pyom_obj.allocate_tke_module()
    pyom_obj.allocate_eke_module()
    pyom_obj.allocate_idemix_module()
    # set variables
    for var, val in state.variables.items():
        var = VEROS_TO_PYOM_VAR.get(var, var)
        if var is None or var in ignore_attrs:
            continue
        if var in INIT_STREAM_VARS:
            # filled below via streamfunction_init, not copied directly
            continue
        set_fortran_attr(var, val)
    if init_streamfunction is None:
        init_streamfunction = state.settings.enable_streamfunction
    if init_streamfunction:
        with suppress_stdout():
            pyom_obj.streamfunction_init()
        for var in INIT_STREAM_VARS:
            set_fortran_attr(var, state.variables.get(var))
    # correct for 1-based indexing
    pyom_obj.main_module.tau += 1
    pyom_obj.main_module.taup1 += 1
    pyom_obj.main_module.taum1 += 1
    # diagnostics: mirror output/sampling intervals onto the Fortran settings
    diag_settings = (
        ("cfl_monitor", "output_frequency", "ts_monint"),
        ("tracer_monitor", "output_frequency", "trac_cont_int"),
        ("snapshot", "output_frequency", "snapint"),
        ("averages", "output_frequency", "aveint"),
        ("averages", "sampling_frequency", "avefreq"),
        ("overturning", "output_frequency", "overint"),
        ("overturning", "sampling_frequency", "overfreq"),
        ("energy", "output_frequency", "energint"),
        ("energy", "sampling_frequency", "energfreq"),
    )
    for diag, param, attr in diag_settings:
        if diag in state.diagnostics:
            # read off the diagnostic object, not its name string
            set_fortran_attr(attr, getattr(state.diagnostics[diag], param))
    return pyom_obj
def _override_settings(pyom_obj):
    """Manually force some settings to ensure compatibility."""
    m = pyom_obj.main_module
    m.n_pes_i, m.n_pes_j = runtime_settings.num_proc
    # define processor boundary idx (1-based)
    ipx, ipy = runtime_state.proc_idx
    m.is_pe = (m.nx // m.n_pes_i) * ipx + 1
    m.ie_pe = (m.nx // m.n_pes_i) * (ipx + 1)
    m.js_pe = (m.ny // m.n_pes_j) * ipy + 1
    m.je_pe = (m.ny // m.n_pes_j) * (ipy + 1)
    # force settings that are not supported by Veros
    idm = pyom_obj.idemix_module
    eke = pyom_obj.eke_module
    m.enable_hydrostatic = True
    m.congr_epsilon = 1e-12
    m.congr_max_iterations = 10_000
    m.enable_congrad_verbose = False
    m.enable_free_surface = True
    eke.enable_eke_leewave_dissipation = False
    idm.enable_idemix_m2 = False
    idm.enable_idemix_niw = False
    return pyom_obj
def state_from_pyom(pyom_obj):
    """Create a Veros state object populated from the given PyOM library state.

    Settings and variables are read back from the Fortran modules (translating
    names via VEROS_TO_PYOM_*); variables without a PyOM counterpart keep
    their defaults.
    """
    from veros.core.operators import numpy as npx
    state = get_default_state()
    pyom_modules = (
        pyom_obj.main_module,
        pyom_obj.isoneutral_module,
        pyom_obj.idemix_module,
        pyom_obj.tke_module,
        pyom_obj.eke_module,
    )
    def get_fortran_attr(attr):
        # fortran interface is all lower-case
        attr = attr.lower()
        for module in pyom_modules:
            if hasattr(module, attr):
                return getattr(module, attr)
        else:
            raise RuntimeError(f"Could not get attribute {attr} from Fortran library")
    with state.settings.unlock():
        for setting in state.settings.fields():
            setting = VEROS_TO_PYOM_SETTING.get(setting, setting)
            if setting is None:
                continue
            state.settings.update({setting: get_fortran_attr(setting)})
    state.initialize_variables()
    with state.variables.unlock():
        if state.settings.enable_streamfunction:
            # number of islands is only known after PyOM's topography setup
            resize_dimension(state, "isle", int(pyom_obj.main_module.nisle))
            state.variables.isle = npx.arange(state.dimensions["isle"])
        for var, val in state.variables.items():
            var = VEROS_TO_PYOM_VAR.get(var, var)
            if var is None:
                continue
            try:
                new_val = get_fortran_attr(var)
            except RuntimeError:
                # variable has no PyOM counterpart; keep the Veros default
                continue
            if new_val is None:
                continue
            try:
                new_val = npx.broadcast_to(new_val, val.shape)
            except ValueError:
                raise ValueError(f"variable {var} has incompatible shapes: {val.shape}, {new_val.shape}")
            state.variables.update({var: new_val})
    return state
def setup_pyom(pyom_obj, set_parameter, set_grid, set_coriolis, set_topography, set_initial_conditions, set_forcing):
    """Run the PyOM2 setup sequence, interleaving the user-supplied callbacks.

    Each callback receives `pyom_obj` and mutates the Fortran module state;
    the library calls in between depend on the preceding callback's results,
    so the order here is load-bearing.
    """
    if runtime_state.proc_num > 1:
        pyom_obj.my_mpi_init(runtime_settings.mpi_comm.py2f())
    else:
        pyom_obj.my_mpi_init(0)
    set_parameter(pyom_obj)
    pyom_obj.pe_decomposition()
    pyom_obj.allocate_main_module()
    pyom_obj.allocate_isoneutral_module()
    pyom_obj.allocate_tke_module()
    pyom_obj.allocate_eke_module()
    pyom_obj.allocate_idemix_module()
    set_grid(pyom_obj)
    pyom_obj.calc_grid()
    set_coriolis(pyom_obj)
    pyom_obj.calc_beta()
    set_topography(pyom_obj)
    pyom_obj.calc_topo()
    pyom_obj.calc_spectral_topo()
    set_initial_conditions(pyom_obj)
    pyom_obj.calc_initial_conditions()
    pyom_obj.streamfunction_init()
    set_forcing(pyom_obj)
    pyom_obj.check_isoneutral_slope_crit()
def run_pyom(pyom_obj, set_forcing, after_timestep=None):
    """Run the PyOM2 main time-stepping loop until `m.runlen` model seconds.

    `set_forcing(pyom_obj)` is called at the start of every step; the optional
    `after_timestep(pyom_obj)` hook runs at the end of every step before the
    time levels are rotated. Timing statistics are logged when done.
    """
    timers = defaultdict(timer.Timer)
    f = pyom_obj
    m = pyom_obj.main_module
    idm = pyom_obj.idemix_module
    ekm = pyom_obj.eke_module
    tkm = pyom_obj.tke_module
    logger.info(f"Starting integration for {float(m.runlen):.2e}s")
    m.time = 0.0
    while m.time < m.runlen:
        logger.info(f"Current iteration: {m.itt}")
        with timers["main"]:
            set_forcing(pyom_obj)
            if idm.enable_idemix:
                f.set_idemix_parameter()
            f.set_eke_diffusivities()
            f.set_tke_diffusivities()
            with timers["momentum"]:
                f.momentum()
            with timers["temperature"]:
                f.thermodynamics()
            if ekm.enable_eke or tkm.enable_tke or idm.enable_idemix:
                f.calculate_velocity_on_wgrid()
            with timers["eke"]:
                if ekm.enable_eke:
                    f.integrate_eke()
            with timers["idemix"]:
                if idm.enable_idemix:
                    f.integrate_idemix()
            with timers["tke"]:
                if tkm.enable_tke:
                    f.integrate_tke()
            """
            Main boundary exchange
            for density, temp and salt this is done in integrate_tempsalt.f90
            """
            # exchange ghost cells / apply cyclic boundaries at the new time
            # level (taup1; -1 converts the 1-based Fortran index)
            f.border_exchg_xyz(
                m.is_pe - m.onx, m.ie_pe + m.onx, m.js_pe - m.onx, m.je_pe + m.onx, m.u[:, :, :, m.taup1 - 1], m.nz
            )
            f.setcyclic_xyz(
                m.is_pe - m.onx, m.ie_pe + m.onx, m.js_pe - m.onx, m.je_pe + m.onx, m.u[:, :, :, m.taup1 - 1], m.nz
            )
            f.border_exchg_xyz(
                m.is_pe - m.onx, m.ie_pe + m.onx, m.js_pe - m.onx, m.je_pe + m.onx, m.v[:, :, :, m.taup1 - 1], m.nz
            )
            f.setcyclic_xyz(
                m.is_pe - m.onx, m.ie_pe + m.onx, m.js_pe - m.onx, m.je_pe + m.onx, m.v[:, :, :, m.taup1 - 1], m.nz
            )
            if tkm.enable_tke:
                f.border_exchg_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    tkm.tke[:, :, :, m.taup1 - 1],
                    m.nz,
                )
                f.setcyclic_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    tkm.tke[:, :, :, m.taup1 - 1],
                    m.nz,
                )
            if ekm.enable_eke:
                f.border_exchg_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    ekm.eke[:, :, :, m.taup1 - 1],
                    m.nz,
                )
                f.setcyclic_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    ekm.eke[:, :, :, m.taup1 - 1],
                    m.nz,
                )
            if idm.enable_idemix:
                f.border_exchg_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    idm.e_iw[:, :, :, m.taup1 - 1],
                    m.nz,
                )
                f.setcyclic_xyz(
                    m.is_pe - m.onx,
                    m.ie_pe + m.onx,
                    m.js_pe - m.onx,
                    m.je_pe + m.onx,
                    idm.e_iw[:, :, :, m.taup1 - 1],
                    m.nz,
                )
            # diagnose vertical velocity at taup1
            f.vertical_velocity()
        # diagnose isoneutral streamfunction regardless of output settings
        f.isoneutral_diag_streamfunction()
        # shift time
        m.itt += 1
        m.time += m.dt_tracer
        if callable(after_timestep):
            after_timestep(pyom_obj)
        # rotate the three time levels (taum1 <- tau <- taup1 <- old taum1)
        orig_taum1 = int(m.taum1)
        m.taum1 = m.tau
        m.tau = m.taup1
        m.taup1 = orig_taum1
        # NOTE: benchmarks parse this, do not change / remove
        logger.debug("Time step took {}s", timers["main"].last_time)
    logger.debug("Timing summary:")
    logger.debug(" setup time summary = {}s", timers["setup"].total_time)
    logger.debug(" main loop time summary = {}s", timers["main"].total_time)
    logger.debug("     momentum = {}s", timers["momentum"].total_time)
    logger.debug("     thermodynamics = {}s", timers["temperature"].total_time)
    logger.debug("     EKE = {}s", timers["eke"].total_time)
    logger.debug("     IDEMIX = {}s", timers["idemix"].total_time)
    logger.debug("     TKE = {}s", timers["tke"].total_time)
def _generate_random_var(state, var):
    """Generate a random but physically plausible value for variable *var*.

    Used to populate a Veros state with synthetic data for testing.
    Topography (``kbot``) and grid spacings are special-cased so the
    resulting setup stays usable; all other variables are filled with
    noise matching the variable's declared dtype.

    Parameters
    ----------
    state : Veros state object providing ``var_meta``, ``dimensions``
        and ``settings``.
    var : str
        Name of the variable to generate.

    Returns
    -------
    numpy.ndarray with the variable's local shape.

    Raises
    ------
    TypeError
        If the variable's dtype is neither floating, integer, nor boolean.
    """
    import numpy as onp

    meta = state.var_meta[var]
    shape = get_shape(state.dimensions, meta.dims)
    global_shape = get_shape(state.dimensions, meta.dims, local=False)

    if var == "kbot":
        # random ocean depth levels in the domain interior; the 2-cell
        # boundary halo stays land (0)
        val = onp.zeros(shape)
        val[2:-2, 2:-2] = onp.random.randint(1, state.dimensions["zt"], size=(shape[0] - 4, shape[1] - 4))
        # sprinkle a few land cells ("islands") into the interior;
        # onp.random.choice may repeat indices, so there are *up to* 10 islands
        island_mask = onp.random.choice(val[3:-3, 3:-3].size, size=10)
        val[3:-3, 3:-3].flat[island_mask] = 0
        return val

    if var in ("dxt", "dxu", "dyt", "dyu"):
        # horizontal grid spacings: near-uniform with 1% jitter, spanning
        # roughly 80 degrees (spherical) or 10,000 km (Cartesian) in total
        if state.settings.coord_degree:
            val = 80 / global_shape[0] * (1 + 1e-2 * onp.random.randn(*shape))
        else:
            val = 10_000e3 / global_shape[0] * (1 + 1e-2 * onp.random.randn(*shape))
        return val

    if var in ("dzt", "dzw"):
        # vertical grid spacings summing to roughly 6000 m
        val = 6000 / global_shape[0] * (1 + 1e-2 * onp.random.randn(*shape))
        return val

    if onp.issubdtype(onp.dtype(meta.dtype), onp.floating):
        val = onp.random.randn(*shape)
        if var in ("salt",):
            # shift salinity to a realistic mean of ~35 psu
            val = 35 + val
        return val

    if onp.issubdtype(onp.dtype(meta.dtype), onp.integer):
        val = onp.random.randint(0, 100, size=shape)
        return val

    if onp.issubdtype(onp.dtype(meta.dtype), onp.bool_):
        # BUG FIX: randint's upper bound is exclusive, so randint(0, 1)
        # always returns 0 (every flag False). Use randint(0, 2) to get
        # genuinely random booleans.
        return onp.random.randint(0, 2, size=shape, dtype="bool")

    raise TypeError(f"got unrecognized dtype: {meta.dtype}")
def get_random_state(pyom2_lib=None, extra_settings=None):
    """Generates random Veros and PyOM states (for testing)"""
    from veros.core import numerics, external

    settings_overrides = {} if extra_settings is None else extra_settings

    state = get_default_state()
    settings = state.settings

    with settings.unlock():
        settings.update(settings_overrides)

    state.initialize_variables()
    state.variables.__locked__ = False  # leave variables unlocked

    for var, meta in state.var_meta.items():
        # skip inactive variables and the time-level pointers
        if not meta.active or var in ("tau", "taup1", "taum1"):
            continue
        setattr(state.variables, var, _generate_random_var(state, var))

    # ensure that masks and geometries are consistent with grid spacings
    numerics.calc_grid(state)
    numerics.calc_topo(state)

    if settings.enable_streamfunction:
        external.streamfunction_init(state)

    if pyom2_lib is None:
        return state

    pyom_obj = pyom_from_state(state, load_pyom(pyom2_lib))
    return state, pyom_obj
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment