Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
import numpy as np
from mpi4py import MPI
from veros import runtime_settings as rs, runtime_state as rst
from veros.distributed import scatter
global_arr = np.array(
[
[0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
[0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
[0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
[0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
[1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
[1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
[1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
[1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 3.0, 3.0],
]
)
if rst.proc_num == 1:
import sys
comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)
res = np.empty((6, 6))
proc_slices = (
(slice(None, -2), slice(None, -2)),
(slice(2, None), slice(None, -2)),
(slice(None, -2), slice(2, None)),
(slice(2, None), slice(2, None)),
)
for proc, idx in enumerate(proc_slices):
comm.Recv(res, proc)
np.testing.assert_array_equal(res, global_arr[idx])
else:
rs.num_proc = (2, 2)
assert rst.proc_num == 4
from veros.core.operators import numpy as npx
dimensions = dict(xt=4, yt=4)
if rst.proc_rank == 0:
a = npx.array(global_arr)
else:
a = npx.empty((6, 6))
b = scatter(a, dimensions, ("xt", "yt"))
rs.mpi_comm.Get_parent().Send(np.array(b), 0)
import sys
import numpy as onp
from mpi4py import MPI
from veros import runtime_settings as rs, runtime_state as rst
rs.diskless_mode = True
if rst.proc_num > 1:
rs.num_proc = (2, 2)
assert rst.proc_num == 4
from veros.state import get_default_state, resize_dimension # noqa: E402
from veros.distributed import gather # noqa: E402
from veros.core.operators import numpy as npx, update, at # noqa: E402
from veros.core.external.solvers import get_linear_solver # noqa: E402
def get_inputs():
state = get_default_state()
settings = state.settings
with settings.unlock():
settings.nx = 100
settings.ny = 40
settings.nz = 1
settings.enable_cyclic_x = True
settings.enable_streamfunction = True
state.initialize_variables()
resize_dimension(state, "isle", 1)
vs = state.variables
nx_local, ny_local = settings.nx // rs.num_proc[0], settings.ny // rs.num_proc[1]
idx_global = (
slice(rst.proc_idx[0] * nx_local, (rst.proc_idx[0] + 1) * nx_local + 4),
slice(rst.proc_idx[1] * ny_local, (rst.proc_idx[1] + 1) * ny_local + 4),
Ellipsis,
)
with vs.unlock():
vs.dxt = update(vs.dxt, at[...], 10e3)
vs.dxu = update(vs.dxu, at[...], 10e3)
vs.dyt = update(vs.dyt, at[...], 10e3)
vs.dyu = update(vs.dyu, at[...], 10e3)
hr_global = (
1.0 / npx.linspace(500, 2000, settings.nx + 4)[:, None] * npx.ones((settings.nx + 4, settings.ny + 4))
)
vs.hur = hr_global[idx_global]
vs.hvr = hr_global[idx_global]
vs.cosu = update(vs.cosu, at[...], 1)
vs.cost = update(vs.cost, at[...], 1)
boundary_mask = npx.ones((settings.nx + 4, settings.ny + 4), dtype="bool")
boundary_mask = update(boundary_mask, at[:50, :2], 0)
boundary_mask = update(boundary_mask, at[20:30, 20:30], 0)
vs.isle_boundary_mask = boundary_mask[idx_global]
rhs = npx.ones_like(vs.hur)
x0 = npx.zeros_like(vs.hur)
return state, rhs, x0
if rst.proc_num == 1:
comm = MPI.COMM_SELF.Spawn(sys.executable, args=["-m", "mpi4py", sys.argv[-1]], maxprocs=4)
try:
state, rhs, x0 = get_inputs()
sol = get_linear_solver(state)
psi = sol.solve(state, rhs, x0)
except Exception as exc:
print(str(exc))
comm.Abort(1)
raise
other_psi = onp.empty_like(psi)
comm.Recv(other_psi, 0)
onp.testing.assert_allclose(psi, other_psi)
else:
state, rhs, x0 = get_inputs()
sol = get_linear_solver(state)
psi = sol.solve(state, rhs, x0)
psi_global = gather(psi, state.dimensions, ("xt", "yt"))
if rst.proc_rank == 0:
rs.mpi_comm.Get_parent().Send(onp.array(psi_global), 0)
import pytest
import numpy as np
from veros.state import get_default_state, resize_dimension
@pytest.fixture
def solver_state(cyclic, problem):
state = get_default_state()
settings = state.settings
with settings.unlock():
settings.nx = 400
settings.ny = 200
settings.nz = 1
settings.dt_tracer = 1800
settings.dt_mom = 1800
settings.enable_cyclic_x = cyclic
settings.enable_streamfunction = problem == "streamfunction"
state.initialize_variables()
resize_dimension(state, "isle", 1)
vs = state.variables
with vs.unlock():
vs.dxt = 10e3 * np.ones(settings.nx + 4)
vs.dxu = 10e3 * np.ones(settings.nx + 4)
vs.dyt = 10e3 * np.ones(settings.ny + 4)
vs.dyu = 10e3 * np.ones(settings.ny + 4)
vs.hur = 1.0 / np.linspace(500, 2000, settings.nx + 4)[:, None] * np.ones((settings.nx + 4, settings.ny + 4))
vs.hvr = 1.0 / np.linspace(500, 2000, settings.ny + 4)[None, :] * np.ones((settings.nx + 4, settings.ny + 4))
vs.hu = 1.0 / vs.hur
vs.hv = 1.0 / vs.hvr
vs.cosu = np.ones(settings.ny + 4)
vs.cost = np.ones(settings.ny + 4)
boundary_mask = np.ones((settings.nx + 4, settings.ny + 4), dtype="bool")
boundary_mask[:100, :2] = 0
boundary_mask[50:100, 50:100] = 0
if settings.enable_streamfunction:
vs.isle_boundary_mask = boundary_mask
maskT = np.zeros((settings.nx + 4, settings.ny + 4, settings.nz), dtype="bool")
if settings.enable_cyclic_x:
maskT[:, 2:-2, 0] = boundary_mask[:, 2:-2]
else:
maskT[2:-2, 2:-2, 0] = boundary_mask[2:-2, 2:-2]
vs.maskT = maskT
return state
def assert_solution(state, rhs, sol, boundary_val=None, tol=1e-8):
from veros.core.external.solvers.scipy import SciPySolver
matrix, boundary_mask = SciPySolver._assemble_poisson_matrix(state)
if boundary_val is None:
boundary_val = sol
rhs = np.where(boundary_mask, rhs, boundary_val)
rhs_sol = matrix @ sol.reshape(-1)
np.testing.assert_allclose(rhs_sol, rhs.flatten(), atol=0, rtol=tol)
@pytest.mark.parametrize("cyclic", [True, False])
@pytest.mark.parametrize("solver", ["scipy", "scipy_jax", "petsc"])
@pytest.mark.parametrize("problem", ["streamfunction", "pressure"])
def test_solver(solver, solver_state, cyclic, problem):
from veros import runtime_settings
from veros.core.operators import numpy as npx
if solver == "scipy":
from veros.core.external.solvers.scipy import SciPySolver
solver_class = SciPySolver
elif solver == "scipy_jax":
if runtime_settings.backend != "jax":
pytest.skip("scipy_jax solver requires JAX")
from veros.core.external.solvers.scipy_jax import JAXSciPySolver
solver_class = JAXSciPySolver
elif solver == "petsc":
petsc_mod = pytest.importorskip("veros.core.external.solvers.petsc_")
solver_class = petsc_mod.PETScSolver
else:
raise ValueError("unknown solver")
settings = solver_state.settings
rhs = npx.ones((settings.nx + 4, settings.ny + 4))
x0 = npx.asarray(np.random.rand(settings.nx + 4, settings.ny + 4))
sol = solver_class(solver_state).solve(solver_state, rhs, x0)
assert_solution(solver_state, rhs, sol, tol=1e-8)
sol = solver_class(solver_state).solve(solver_state, rhs, x0, boundary_val=10)
assert_solution(solver_state, rhs, sol, tol=1e-8, boundary_val=10)
import pytest
from veros.plugins import load_plugin
from veros.routines import veros_routine
from veros.state import get_default_state
from veros.variables import Variable
from veros.settings import Setting
@pytest.fixture
def fake_plugin():
class FakePlugin:
pass
def run_setup(state):
plugin._setup_ran = True
def run_main(state):
plugin._main_ran = True
plugin = FakePlugin()
plugin.__name__ = "foobar"
plugin._setup_ran = False
plugin._main_ran = False
plugin.__VEROS_INTERFACE__ = {
"name": "foo",
"setup_entrypoint": run_setup,
"run_entrypoint": run_main,
"settings": dict(mydimsetting=Setting(15, int, "bar")),
"variables": dict(myvar=Variable("myvar", ("xt", "yt", "mydim"))),
"dimensions": dict(mydim="mydimsetting"),
"diagnostics": [],
}
yield plugin
def test_load_plugin(fake_plugin):
plugin_interface = load_plugin(fake_plugin)
assert plugin_interface.name == "foo"
def test_state_plugin(fake_plugin):
plugin_interface = load_plugin(fake_plugin)
state = get_default_state(plugin_interfaces=plugin_interface)
assert "mydimsetting" in state.settings
assert "mydim" in state.dimensions
assert state.dimensions["mydim"] == state.settings.mydimsetting
state.initialize_variables()
assert "myvar" in state.variables
assert state.variables.myvar.shape == (4, 4, state.settings.mydimsetting)
def test_run_plugin(fake_plugin):
from veros.setups.acc_basic import ACCBasicSetup
class FakeSetup(ACCBasicSetup):
__veros_plugins__ = (fake_plugin,)
@veros_routine
def set_diagnostics(self, state):
pass
setup = FakeSetup(override=dict(dt_tracer=100, runlen=100))
assert not fake_plugin._setup_ran
setup.setup()
assert fake_plugin._setup_ran
assert not fake_plugin._main_ran
setup.run()
assert fake_plugin._main_ran
import sys
import re
import time
import platform
import pytest
class Dummy:
pass
@pytest.mark.xfail(platform.system() == "Darwin", reason="Flaky on OSX")
def test_progress_format(capsys):
from veros.logs import setup_logging
setup_logging(stream_sink=sys.stdout)
from veros.progress import get_progress_bar
dummy_state = Dummy()
dummy_state.settings = Dummy()
dummy_state.variables = Dummy()
dummy_state.settings.runlen = 8000
dummy_state.variables.time = 2000
dummy_state.variables.itt = 2
with get_progress_bar(dummy_state, use_tqdm=False) as pbar:
for _ in range(8):
time.sleep(0.1)
pbar.advance_time(1000)
captured_log = capsys.readouterr()
assert "Current iteration:" in captured_log.out
with get_progress_bar(dummy_state, use_tqdm=True) as pbar:
for _ in range(8):
time.sleep(0.1)
pbar.advance_time(1000)
captured_tqdm = capsys.readouterr()
assert "Current iteration:" in captured_tqdm.out
def sanitize(prog):
# remove rates and ETA (inconsistent)
prog = re.sub(r"\d+\.\d{2}[smh]/\(model year\)", "?s/(model year)", prog)
prog = re.sub(r"\d+\.\d[smh] left", "? left", prog)
prog = prog.replace("\r", "\n")
prog = prog.strip()
return prog
def deduplicate(prog):
# remove repeated identical lines
out = []
for line in prog.split("\n"):
if not out or out[-1] != line:
out.append(line)
return "\n".join(out)
assert sanitize(captured_log.out) == deduplicate(sanitize(captured_tqdm.out))
from functools import partial
import numpy as np
from veros import tools
from veros.routines import veros_routine
from veros.pyom_compat import load_pyom, pyom_from_state, run_pyom
from veros.setups.global_4deg import GlobalFourDegreeSetup
from test_base import compare_state
class GlobalFourDegreeTest(GlobalFourDegreeSetup):
@veros_routine
def set_parameter(self, state):
settings = state.settings
super().set_parameter(state)
settings.runlen = settings.dt_tracer * 100
settings.restart_output_filename = None
# do not exist in pyOM
settings.kappaH_min = 0.0
settings.enable_kappaH_profile = False
settings.enable_Prandtl_tke = True
@veros_routine
def set_forcing(self, state):
vs = state.variables
settings = state.settings
super().set_forcing(state)
vs.surface_taux = vs.surface_taux / settings.rho_0
vs.surface_tauy = vs.surface_tauy / settings.rho_0
@veros_routine
def set_diagnostics(self, state):
state.diagnostics.clear()
def set_forcing_pyom(pyom_obj, vs_state):
vs = vs_state.variables
m = pyom_obj.main_module
year_in_seconds = 360 * 86400.0
time = m.itt * m.dt_tracer
(n1, f1), (n2, f2) = tools.get_periodic_interval(time, year_in_seconds, year_in_seconds / 12.0, 12)
# wind stress
m.surface_taux[...] = (f1 * vs.taux[:, :, n1] + f2 * vs.taux[:, :, n2]) / m.rho_0
m.surface_tauy[...] = (f1 * vs.tauy[:, :, n1] + f2 * vs.tauy[:, :, n2]) / m.rho_0
# tke flux
t = pyom_obj.tke_module
if t.enable_tke:
t.forc_tke_surface[1:-1, 1:-1] = np.sqrt(
(0.5 * (m.surface_taux[1:-1, 1:-1] + m.surface_taux[:-2, 1:-1])) ** 2
+ (0.5 * (m.surface_tauy[1:-1, 1:-1] + m.surface_tauy[1:-1, :-2])) ** 2
) ** (3.0 / 2.0)
# heat flux : W/m^2 K kg/J m^3/kg = K m/s
cp_0 = 3991.86795711963
sst = f1 * vs.sst_clim[:, :, n1] + f2 * vs.sst_clim[:, :, n2]
qnec = f1 * vs.qnec[:, :, n1] + f2 * vs.qnec[:, :, n2]
qnet = f1 * vs.qnet[:, :, n1] + f2 * vs.qnet[:, :, n2]
m.forc_temp_surface[...] = (qnet + qnec * (sst - m.temp[:, :, -1, m.tau - 1])) * m.maskt[:, :, -1] / cp_0 / m.rho_0
# salinity restoring
t_rest = 30 * 86400.0
sss = f1 * vs.sss_clim[:, :, n1] + f2 * vs.sss_clim[:, :, n2]
m.forc_salt_surface[:] = 1.0 / t_rest * (sss - m.salt[:, :, -1, m.tau - 1]) * m.maskt[:, :, -1] * m.dzt[-1]
# apply simple ice mask
mask = np.logical_and(m.temp[:, :, -1, m.tau - 1] * m.maskt[:, :, -1] < -1.8, m.forc_temp_surface < 0.0)
m.forc_temp_surface[mask] = 0.0
m.forc_salt_surface[mask] = 0.0
if m.enable_tempsalt_sources:
m.temp_source[:] = (
m.maskt
* vs.rest_tscl
* (f1 * vs.t_star[:, :, :, n1] + f2 * vs.t_star[:, :, :, n2] - m.temp[:, :, :, m.tau - 1])
)
m.salt_source[:] = (
m.maskt
* vs.rest_tscl
* (f1 * vs.s_star[:, :, :, n1] + f2 * vs.s_star[:, :, :, n2] - m.salt[:, :, :, m.tau - 1])
)
def test_4deg(pyom2_lib):
sim = GlobalFourDegreeTest()
sim.setup()
pyom_obj = load_pyom(pyom2_lib)
pyom_obj = pyom_from_state(
sim.state, pyom_obj, ignore_attrs=("taux", "tauy", "sss_clim", "sst_clim", "qnec", "qnet")
)
sim.run()
run_pyom(pyom_obj, partial(set_forcing_pyom, vs_state=sim.state))
# test passes if differences are less than 0.1% of the maximum value of each variable
compare_state(
sim.state,
pyom_obj,
normalize=True,
rtol=0,
atol=1e-4,
allowed_failures=("Ai_ez", "Ai_nz", "Ai_bx", "Ai_by"),
)
import numpy as np
from veros import VerosSetup, veros_routine
from veros.variables import allocate, Variable
from veros.core.operators import numpy as npx, update, at
from veros.pyom_compat import load_pyom, setup_pyom
from test_base import compare_state
yt_start = -39.0
yt_end = 43
yu_start = -40.0
yu_end = 42
def set_parameter_pyom(pyom_obj):
m = pyom_obj.main_module
(m.nx, m.ny, m.nz) = (30, 42, 15)
m.dt_mom = 4800
m.dt_tracer = 86400 / 2.0
m.runlen = 86400 * 365
m.coord_degree = 1
m.enable_cyclic_x = 1
m.congr_epsilon = 1e-8
m.congr_max_iterations = 10_000
m.ab_eps = 0.1
i = pyom_obj.isoneutral_module
i.enable_neutral_diffusion = 1
i.k_iso_0 = 1000.0
i.k_iso_steep = 500.0
i.iso_dslope = 0.005
i.iso_slopec = 0.01
i.enable_skew_diffusion = 1
m.enable_hor_friction = 1
m.a_h = 2.2e5
m.enable_hor_friction_cos_scaling = 1
m.hor_friction_cospower = 1
m.enable_bottom_friction = 1
m.r_bot = 1e-5
m.enable_streamfunction = True
m.enable_implicit_vert_friction = 1
t = pyom_obj.tke_module
t.enable_tke = 1
t.c_k = 0.1
t.c_eps = 0.7
t.alpha_tke = 30.0
t.mxl_min = 1e-8
t.tke_mxl_choice = 2
t.kappam_min = 2e-4
i.k_gm_0 = 1000.0
e = pyom_obj.eke_module
e.enable_eke = 1
e.eke_k_max = 1e4
e.eke_c_k = 0.4
e.eke_c_eps = 0.5
e.eke_cross = 2.0
e.eke_crhin = 1.0
e.eke_lmin = 100.0
e.enable_eke_superbee_advection = 1
e.enable_eke_isopycnal_diffusion = 1
i = pyom_obj.idemix_module
i.enable_idemix = 1
i.enable_idemix_hor_diffusion = 1
i.enable_eke_diss_surfbot = 1
i.eke_diss_surfbot_frac = 0.2
i.enable_idemix_superbee_advection = 1
i.tau_v = 86400.0
i.jstar = 10.0
i.mu0 = 4.0 / 3.0
i.gamma = 1.57
m.eq_of_state_type = 3
def set_grid_pyom(pyom_obj):
m = pyom_obj.main_module
ddz = [50.0, 70.0, 100.0, 140.0, 190.0, 240.0, 290.0, 340.0, 390.0, 440.0, 490.0, 540.0, 590.0, 640.0, 690.0]
m.dxt[:] = 2.0
m.dyt[:] = 2.0
m.x_origin = 0.0
m.y_origin = -40.0
m.dzt[:] = ddz[::-1]
m.dzt[:] *= 1 / 2.5
def set_coriolis_pyom(pyom_obj):
m = pyom_obj.main_module
m.coriolis_t[:, :] = 2 * m.omega * np.sin(m.yt[None, :] / 180.0 * np.pi)
def set_topography_pyom(pyom_obj):
m = pyom_obj.main_module
(X, Y) = np.meshgrid(m.xt, m.yt)
X = X.transpose()
Y = Y.transpose()
m.kbot[...] = (X > 1.0) | (Y < -20)
def set_initial_conditions_pyom(pyom_obj):
m = pyom_obj.main_module
# initial conditions
m.temp[:, :, :, :] = ((1 - m.zt[None, None, :] / m.zw[0]) * 15 * m.maskt)[..., None]
m.salt[:, :, :, :] = 35.0 * m.maskt[..., None]
# wind stress forcing
taux = np.zeros(m.ny + 1)
yt = m.yt[2 : m.ny + 3]
taux = (0.1e-3 * np.sin(np.pi * (m.yu[2 : m.ny + 3] - yu_start) / (-20.0 - yt_start))) * (yt < -20) + (
0.1e-3 * (1 - np.cos(2 * np.pi * (m.yu[2 : m.ny + 3] - 10.0) / (yu_end - 10.0)))
) * (yt > 10)
m.surface_taux[:, 2 : m.ny + 3] = taux * m.masku[:, 2 : m.ny + 3, -1]
t = pyom_obj.tke_module
t.forc_tke_surface[2:-2, 2:-2] = (
np.sqrt(
(0.5 * (m.surface_taux[2:-2, 2:-2] + m.surface_taux[1:-3, 2:-2])) ** 2
+ (0.5 * (m.surface_tauy[2:-2, 2:-2] + m.surface_tauy[2:-2, 1:-3])) ** 2
)
** 1.5
)
def set_forcing_pyom(pyom_obj):
m = pyom_obj.main_module
t_star = (
15 * np.invert((m.yt < -20) | (m.yt > 20))
+ 15 * (m.yt - yt_start) / (-20 - yt_start) * (m.yt < -20)
+ 15 * (1 - (m.yt - 20) / (yt_end - 20)) * (m.yt > 20.0)
)
t_rest = m.dzt[None, -1] / (30.0 * 86400.0) * m.maskt[:, :, -1]
m.forc_temp_surface = t_rest * (t_star - m.temp[:, :, -1, m.tau - 1])
class ACCSetup(VerosSetup):
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "acc"
settings.nx, settings.ny, settings.nz = 30, 42, 15
settings.dt_mom = 4800
settings.dt_tracer = 86400 / 2.0
settings.runlen = 86400 * 365
settings.x_origin = 0.0
settings.y_origin = -40.0
settings.coord_degree = True
settings.enable_cyclic_x = True
settings.enable_streamfunction = True
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 500.0
settings.iso_dslope = 0.005
settings.iso_slopec = 0.01
settings.enable_skew_diffusion = True
settings.enable_hor_friction = True
settings.A_h = 2.2e5
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_bottom_friction = True
settings.r_bot = 1e-5
settings.enable_implicit_vert_friction = True
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.K_gm_0 = 1000.0
settings.enable_eke = True
settings.eke_k_max = 1e4
settings.eke_c_k = 0.4
settings.eke_c_eps = 0.5
settings.eke_cross = 2.0
settings.eke_crhin = 1.0
settings.eke_lmin = 100.0
settings.enable_eke_superbee_advection = True
settings.enable_eke_isopycnal_diffusion = True
settings.enable_idemix = 1
settings.enable_idemix_hor_diffusion = 1
settings.enable_eke_diss_surfbot = 1
settings.eke_diss_surfbot_frac = 0.2
settings.enable_idemix_superbee_advection = 1
settings.tau_v = 86400.0
settings.jstar = 10.0
settings.mu0 = 4.0 / 3.0
settings.eq_of_state_type = 3
var_meta = state.var_meta
var_meta.update(
t_star=Variable("t_star", ("yt",), "deg C", "Reference surface temperature"),
t_rest=Variable("t_rest", ("xt", "yt"), "1/s", "Surface temperature restoring time scale"),
)
@veros_routine
def set_grid(self, state):
vs = state.variables
ddz = npx.array(
[50.0, 70.0, 100.0, 140.0, 190.0, 240.0, 290.0, 340.0, 390.0, 440.0, 490.0, 540.0, 590.0, 640.0, 690.0]
)
vs.dxt = update(vs.dxt, at[...], 2.0)
vs.dyt = update(vs.dyt, at[...], 2.0)
vs.dzt = update(vs.dzt, at[...], ddz[::-1] / 2.5)
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[None, :] / 180.0 * settings.pi)
)
@veros_routine
def set_topography(self, state):
vs = state.variables
x, y = npx.meshgrid(vs.xt, vs.yt, indexing="ij")
vs.kbot = npx.logical_or(x > 1.0, y < -20).astype("int")
@veros_routine
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
# initial conditions
vs.temp = update(vs.temp, at[...], ((1 - vs.zt[None, None, :] / vs.zw[0]) * 15 * vs.maskT)[..., None])
vs.salt = update(vs.salt, at[...], 35.0 * vs.maskT[..., None])
# wind stress forcing
taux = allocate(state.dimensions, ("yt",))
taux = npx.where(vs.yt < -20, 0.1e-3 * npx.sin(settings.pi * (vs.yu - yu_start) / (-20.0 - yt_start)), taux)
taux = npx.where(vs.yt > 10, 0.1e-3 * (1 - npx.cos(2 * settings.pi * (vs.yu - 10.0) / (yu_end - 10.0))), taux)
vs.surface_taux = taux * vs.maskU[:, :, -1]
# surface heatflux forcing
vs.t_star = allocate(state.dimensions, ("yt",), fill=15)
vs.t_star = npx.where(vs.yt < -20, 15 * (vs.yt - yt_start) / (-20 - yt_start), vs.t_star)
vs.t_star = npx.where(vs.yt > 20, 15 * (1 - (vs.yt - 20) / (yt_end - 20)), vs.t_star)
vs.t_rest = vs.dzt[npx.newaxis, -1] / (30.0 * 86400.0) * vs.maskT[:, :, -1]
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[2:-2, 2:-2],
npx.sqrt(
(0.5 * (vs.surface_taux[2:-2, 2:-2] + vs.surface_taux[1:-3, 2:-2])) ** 2
+ (0.5 * (vs.surface_tauy[2:-2, 2:-2] + vs.surface_tauy[2:-2, 1:-3])) ** 2
)
** (1.5),
)
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.forc_temp_surface = vs.t_rest * (vs.t_star - vs.temp[:, :, -1, vs.tau])
@veros_routine
def set_diagnostics(self, state):
pass
@veros_routine
def after_timestep(self, state):
pass
def test_acc_setup(pyom2_lib):
pyom_obj = load_pyom(pyom2_lib)
setup_pyom(
pyom_obj,
set_parameter_pyom,
set_grid_pyom,
set_coriolis_pyom,
set_topography_pyom,
set_initial_conditions_pyom,
set_forcing_pyom,
)
sim = ACCSetup()
sim.setup()
# Veros runs a streamfunction solve during setup
allowed_failures = ("p_hydro",)
# psin and line_psin don't quite meet the tolerance
compare_state(sim.state, pyom_obj, rtol=1e-6, allowed_failures=allowed_failures)
import pytest
from veros.routines import veros_routine
from veros.pyom_compat import load_pyom, pyom_from_state, run_pyom
from veros.setups.acc import ACCSetup
from test_base import compare_state
TEST_SETS = {
"standard": dict(),
"pressure": dict(enable_streamfunction=False),
"no-energy-conservation": dict(enable_conserve_energy=False),
}
class ACCTest(ACCSetup):
@veros_routine
def set_parameter(self, state):
settings = state.settings
super().set_parameter(state)
settings.runlen = settings.dt_tracer * 100
settings.restart_output_filename = None
# do not exist in pyOM
settings.kappaH_min = 0.0
settings.enable_kappaH_profile = False
settings.enable_Prandtl_tke = True
@veros_routine
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
super().set_initial_conditions(state)
vs.surface_taux = vs.surface_taux / settings.rho_0
@veros_routine
def set_diagnostics(self, state):
state.diagnostics.clear()
@pytest.mark.parametrize("test_set", TEST_SETS.keys())
def test_acc(pyom2_lib, test_set):
extra_settings = TEST_SETS[test_set]
sim = ACCTest(override=extra_settings)
sim.setup()
pyom_obj = load_pyom(pyom2_lib)
pyom_obj = pyom_from_state(sim.state, pyom_obj, ignore_attrs=("t_star", "t_rest"))
t_rest = sim.state.variables.t_rest
t_star = sim.state.variables.t_star
sim.run()
def set_forcing_pyom(pyom_obj):
m = pyom_obj.main_module
m.forc_temp_surface[:] = t_rest * (t_star - m.temp[:, :, -1, m.tau - 1])
run_pyom(pyom_obj, set_forcing_pyom)
# salt is not used by this setup
allowed_failures = ("salt", "dsalt", "dsalt_vmix", "dsalt_iso")
atol = 1e-8
if test_set == "pressure":
# pressure setups are more numerically sensitive, stick to "observables"
atol = 1e-5
allowed_failures = set(sim.state.variables.fields()) - {"u", "v", "temp", "psi"}
compare_state(
sim.state,
pyom_obj,
atol=atol,
rtol=0,
normalize=True,
allowed_failures=allowed_failures,
)
import numpy as np
from veros.core import advection
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
)
def test_calculate_velocity_on_wgrid(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
advection.calculate_velocity_on_wgrid(vs_state)
pyom_obj.calculate_velocity_on_wgrid()
compare_state(vs_state, pyom_obj)
def test_adv_flux_2nd(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
res = advection.adv_flux_2nd(vs_state, vs_state.variables.Hd[..., 1])
m = pyom_obj.main_module
pyom_obj.adv_flux_2nd(
is_=-1,
ie_=m.nx + 2,
js_=-1,
je_=m.ny + 2,
nz_=m.nz,
adv_fe=m.flux_east,
adv_fn=m.flux_north,
adv_ft=m.flux_top,
var=m.hd[..., 1],
)
np.testing.assert_allclose(res[0], m.flux_east)
np.testing.assert_allclose(res[1], m.flux_north)
np.testing.assert_allclose(res[2], m.flux_top)
def test_adv_flux_superbee(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
res = advection.adv_flux_superbee(vs_state, vs_state.variables.Hd[..., 1])
m = pyom_obj.main_module
pyom_obj.adv_flux_superbee(
is_=-1,
ie_=m.nx + 2,
js_=-1,
je_=m.ny + 2,
nz_=m.nz,
adv_fe=m.flux_east,
adv_fn=m.flux_north,
adv_ft=m.flux_top,
var=m.hd[..., 1],
)
np.testing.assert_allclose(res[0], m.flux_east)
np.testing.assert_allclose(res[1], m.flux_north)
np.testing.assert_allclose(res[2], m.flux_top)
def test_adv_flux_upwind_wgrid(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
res = advection.adv_flux_upwind_wgrid(vs_state, vs_state.variables.Hd[..., 1])
m = pyom_obj.main_module
pyom_obj.adv_flux_upwind_wgrid(
is_=-1,
ie_=m.nx + 2,
js_=-1,
je_=m.ny + 2,
nz_=m.nz,
adv_fe=m.flux_east,
adv_fn=m.flux_north,
adv_ft=m.flux_top,
var=m.hd[..., 1],
)
np.testing.assert_allclose(res[0], m.flux_east)
np.testing.assert_allclose(res[1], m.flux_north)
np.testing.assert_allclose(res[2], m.flux_top)
def test_adv_flux_superbee_wgrid(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
res = advection.adv_flux_superbee_wgrid(vs_state, vs_state.variables.Hd[..., 1])
m = pyom_obj.main_module
pyom_obj.adv_flux_superbee_wgrid(
is_=-1,
ie_=m.nx + 2,
js_=-1,
je_=m.ny + 2,
nz_=m.nz,
adv_fe=m.flux_east,
adv_fn=m.flux_north,
adv_ft=m.flux_top,
var=m.hd[..., 1],
)
np.testing.assert_allclose(res[0], m.flux_east)
np.testing.assert_allclose(res[1], m.flux_north)
np.testing.assert_allclose(res[2], m.flux_top)
import sys
import importlib
import pytest
def pytest_collection_modifyitems(items):
for item in items:
item.add_marker("forked")
@pytest.fixture(autouse=True)
def setup_test():
import veros
from veros.logs import setup_logging
setup_logging(loglevel="warning")
object.__setattr__(veros.runtime_settings, "pyom_compatibility_mode", True)
# reload all core modules to make sure changes take effect
for name, mod in list(sys.modules.items()):
if name.startswith("veros.core"):
importlib.reload(mod)
try:
yield
finally:
object.__setattr__(veros.runtime_settings, "pyom_compatibility_mode", False)
for name, mod in list(sys.modules.items()):
if name.startswith("veros.core"):
importlib.reload(mod)
from veros.core import diffusion
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_cyclic_x=True,
enable_conserve_energy=True,
enable_hor_friction_cos_scaling=True,
enable_tempsalt_sources=True,
K_hbi=1,
K_h=1,
hor_friction_cosPower=2,
)
def prepare_inputs(vs_state, pyom_obj):
# implementations are only identical if non-water values are 0
vs = vs_state.variables
for var in (
"P_diss_sources",
"P_diss_hmix",
):
getattr(pyom_obj.main_module, var.lower())[...] *= vs.maskT
with vs.unlock():
setattr(vs, var, vs.get(var) * vs.maskT)
return vs_state, pyom_obj
def test_tempsalt_biharmonic(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state, pyom_obj = prepare_inputs(vs_state, pyom_obj)
vs_state.variables.update(diffusion.tempsalt_biharmonic(vs_state))
pyom_obj.tempsalt_biharmonic()
compare_state(vs_state, pyom_obj)
def test_tempsalt_diffusion(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state, pyom_obj = prepare_inputs(vs_state, pyom_obj)
vs_state.variables.update(diffusion.tempsalt_diffusion(vs_state))
pyom_obj.tempsalt_diffusion()
compare_state(vs_state, pyom_obj)
def test_tempsalt_sources(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state, pyom_obj = prepare_inputs(vs_state, pyom_obj)
vs_state.variables.update(diffusion.tempsalt_sources(vs_state))
pyom_obj.tempsalt_sources()
compare_state(vs_state, pyom_obj)
from veros.core import eke
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_cyclic_x=True,
enable_eke=True,
enable_TEM_friction=True,
enable_eke_isopycnal_diffusion=True,
enable_store_cabbeling_heat=True,
enable_eke_superbee_advection=True,
enable_eke_upwind_advection=True,
)
def test_set_eke_diffusivities(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(eke.set_eke_diffusivities(vs_state))
pyom_obj.set_eke_diffusivities()
compare_state(vs_state, pyom_obj)
def test_integrate_eke(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(eke.integrate_eke(vs_state))
pyom_obj.integrate_eke()
compare_state(vs_state, pyom_obj)
from veros.core import friction
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_cyclic_x=True,
enable_conserve_energy=True,
enable_bottom_friction_var=True,
enable_hor_friction_cos_scaling=True,
enable_momentum_sources=True,
r_ray=1,
r_bot=1,
r_quad_bot=1,
A_h=1,
A_hbi=1,
)
def test_explicit_vert_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.explicit_vert_friction(vs_state))
pyom_obj.explicit_vert_friction()
compare_state(vs_state, pyom_obj)
def test_implicit_vert_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.implicit_vert_friction(vs_state))
pyom_obj.implicit_vert_friction()
compare_state(vs_state, pyom_obj)
def test_rayleigh_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.rayleigh_friction(vs_state))
pyom_obj.rayleigh_friction()
compare_state(vs_state, pyom_obj)
def test_linear_bottom_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.linear_bottom_friction(vs_state))
pyom_obj.linear_bottom_friction()
compare_state(vs_state, pyom_obj)
def test_quadratic_bottom_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.quadratic_bottom_friction(vs_state))
pyom_obj.quadratic_bottom_friction()
compare_state(vs_state, pyom_obj)
def test_harmonic_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.harmonic_friction(vs_state))
pyom_obj.harmonic_friction()
compare_state(vs_state, pyom_obj)
def test_biharmonic_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.biharmonic_friction(vs_state))
pyom_obj.biharmonic_friction()
compare_state(vs_state, pyom_obj)
def test_momentum_sources(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(friction.momentum_sources(vs_state))
pyom_obj.momentum_sources()
compare_state(vs_state, pyom_obj)
import pytest
from veros.core import idemix
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_idemix=True,
enable_idemix_hor_diffusion=True,
enable_idemix_superbee_advection=True,
enable_idemix_upwind_advection=True,
enable_eke=True,
enable_store_cabbeling_heat=True,
enable_eke_diss_bottom=True,
enable_eke_diss_surfbot=True,
enable_store_bottom_friction_tke=True,
enable_TEM_friction=True,
)
PROBLEM_SETS = {
"eke": dict(enable_eke=True),
"no-eke": dict(enable_eke=False),
"no-eke_diss_bottom": dict(enable_eke_diss_bottom=False),
"no-eke_diss_surfbot": dict(enable_eke_diss_surfbot=False),
}
def test_set_idemix_parameter(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(idemix.set_idemix_parameter(vs_state))
pyom_obj.set_idemix_parameter()
compare_state(vs_state, pyom_obj)
@pytest.mark.parametrize("problem_set", PROBLEM_SETS)
def test_integrate_idemix(pyom2_lib, problem_set):
settings = {**TEST_SETTINGS, **PROBLEM_SETS[problem_set]}
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=settings)
vs_state.variables.update(idemix.integrate_idemix(vs_state))
pyom_obj.integrate_idemix()
compare_state(vs_state, pyom_obj)
import pytest
from veros.core import isoneutral
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_neutral_diffusion=True,
enable_skew_diffusion=True,
enable_TEM_friction=True,
K_iso_steep=1,
)
def test_isoneutral_diffusion_pre(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(isoneutral.isoneutral_diffusion_pre(vs_state))
pyom_obj.isoneutral_diffusion_pre()
compare_state(vs_state, pyom_obj)
def test_isoneutral_diag_streamfunction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(isoneutral.isoneutral_diag_streamfunction(vs_state))
pyom_obj.isoneutral_diag_streamfunction()
compare_state(vs_state, pyom_obj)
@pytest.mark.parametrize("istemp", [True, False])
def test_isoneutral_diffusion(pyom2_lib, istemp):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
m = pyom_obj.main_module
vs = vs_state.variables
vs.update(isoneutral.isoneutral_diffusion(vs_state, vs.temp if istemp else vs.salt, istemp))
pyom_obj.isoneutral_diffusion(
is_=-1, ie_=m.nx + 2, js_=-1, je_=m.ny + 2, nz_=m.nz, tr=m.temp if istemp else m.salt, istemp=istemp
)
compare_state(vs_state, pyom_obj)
@pytest.mark.parametrize("istemp", [True, False])
def test_isoneutral_skew_diffusion(pyom2_lib, istemp):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
m = pyom_obj.main_module
vs = vs_state.variables
vs.update(isoneutral.isoneutral_skew_diffusion(vs_state, vs.temp if istemp else vs.salt, istemp))
pyom_obj.isoneutral_skew_diffusion(
is_=-1, ie_=m.nx + 2, js_=-1, je_=m.ny + 2, nz_=m.nz, tr=m.temp if istemp else m.salt, istemp=istemp
)
compare_state(vs_state, pyom_obj)
def test_isoneutral_friction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(isoneutral.isoneutral_friction(vs_state))
pyom_obj.isoneutral_friction()
compare_state(vs_state, pyom_obj)
from veros.core import momentum
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
coord_degree=True,
enable_cyclic_x=True,
enable_conserve_energy=True,
enable_bottom_friction_var=True,
enable_hor_friction_cos_scaling=True,
enable_implicit_vert_friction=True,
enable_explicit_vert_friction=True,
enable_TEM_friction=True,
enable_hor_friction=True,
enable_biharmonic_friction=True,
enable_ray_friction=True,
enable_bottom_friction=True,
enable_quadratic_bottom_friction=True,
enable_momentum_sources=True,
)
def test_momentum_advection(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(momentum.momentum_advection(vs_state))
pyom_obj.momentum_advection()
# not a part of momentum_advection in PyOM
m = pyom_obj.main_module
m.du[..., m.tau - 1] += m.du_adv
m.dv[..., m.tau - 1] += m.dv_adv
compare_state(vs_state, pyom_obj)
def test_vertical_velocity(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(momentum.vertical_velocity(vs_state))
pyom_obj.vertical_velocity()
compare_state(vs_state, pyom_obj)
def test_momentum(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
# results are only identical if initial guess is already cyclic
from veros.core import utilities
vs = vs_state.variables
m = pyom_obj.main_module
m.psi[...] = utilities.enforce_boundaries(m.psi, vs_state.settings.enable_cyclic_x)
vs.psi = utilities.enforce_boundaries(vs.psi, vs_state.settings.enable_cyclic_x)
vs_state.variables.update(momentum.momentum(vs_state))
pyom_obj.momentum()
compare_state(vs_state, pyom_obj)
from veros.core import numerics
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_cyclic_x=True,
coord_degree=False,
eq_of_state_type=1,
)
def test_calc_grid(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(numerics.calc_grid(vs_state))
pyom_obj.calc_grid()
compare_state(vs_state, pyom_obj)
def test_calc_topo(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(numerics.calc_topo(vs_state))
pyom_obj.calc_topo()
compare_state(vs_state, pyom_obj)
def test_calc_beta(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(numerics.calc_beta(vs_state))
pyom_obj.calc_beta()
compare_state(vs_state, pyom_obj)
def test_calc_initial_conditions(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(numerics.calc_initial_conditions(vs_state))
pyom_obj.calc_initial_conditions()
compare_state(vs_state, pyom_obj)
from veros.core import external, utilities
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=60,
ny=40,
nz=30,
dt_tracer=12000,
dt_mom=3600,
enable_cyclic_x=True,
enable_streamfunction=False,
)
def test_solve_pressure(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs = vs_state.variables
settings = vs_state.settings
# results are only identical if initial guess is already cyclic
m = pyom_obj.main_module
m.psi[...] = utilities.enforce_boundaries(m.psi, settings.enable_cyclic_x)
vs.psi = utilities.enforce_boundaries(vs.psi, settings.enable_cyclic_x)
vs.update(external.solve_pressure(vs_state))
pyom_obj.solve_pressure()
compare_state(vs_state, pyom_obj)
from veros.core import external
from veros.pyom_compat import get_random_state
from test_base import compare_state
TEST_SETTINGS = dict(
nx=70,
ny=60,
nz=50,
dt_tracer=3600,
dt_mom=3600,
enable_cyclic_x=True,
enable_streamfunction=True,
)
def test_solve_streamfunction(pyom2_lib):
vs_state, pyom_obj = get_random_state(pyom2_lib, extra_settings=TEST_SETTINGS)
vs_state.variables.update(external.solve_streamfunction(vs_state))
pyom_obj.solve_streamfunction()
compare_state(vs_state, pyom_obj)
import numpy as np
from textwrap import indent
from veros.variables import remove_ghosts
from veros.pyom_compat import state_from_pyom, VEROS_TO_PYOM_SETTING, VEROS_TO_PYOM_VAR
def _normalize(*arrays):
if any(a.size == 0 for a in arrays):
return arrays
norm = np.nanmax(np.abs(arrays[0]))
if norm == 0.0:
return arrays
return tuple(a / norm for a in arrays)
def compare_state(
vs_state,
pyom_obj,
atol=1e-10,
rtol=1e-8,
include_ghosts=False,
allowed_failures=None,
normalize=False,
):
IGNORE_SETTINGS = ("congr_max_iterations",)
if allowed_failures is None:
allowed_failures = []
pyom_state = state_from_pyom(pyom_obj)
def assert_setting(setting):
vs_val = vs_state.settings.get(setting)
setting = VEROS_TO_PYOM_SETTING.get(setting, setting)
if setting is None or setting in IGNORE_SETTINGS:
return
pyom_val = pyom_state.settings.get(setting)
assert vs_val == pyom_val, (vs_val, pyom_val)
def assert_var(var):
vs_val = vs_state.variables.get(var)
var = VEROS_TO_PYOM_VAR.get(var, var)
if var is None:
return
if var not in pyom_state.variables:
return
pyom_val = pyom_state.variables.get(var)
if var in ("tau", "taup1", "taum1"):
assert pyom_val == vs_val + 1
return
if not include_ghosts:
vs_val = remove_ghosts(vs_val, vs_state.var_meta[var].dims)
pyom_val = remove_ghosts(pyom_val, pyom_state.var_meta[var].dims)
if normalize:
vs_val, pyom_val = _normalize(vs_val, pyom_val)
np.testing.assert_allclose(vs_val, pyom_val, atol=atol, rtol=rtol)
passed = True
for setting in vs_state.settings.fields():
try:
assert_setting(setting)
except AssertionError as exc:
if setting not in allowed_failures:
print(f"{setting}:{indent(str(exc), ' ' * 4)}")
passed = False
for var in vs_state.variables.fields():
try:
assert_var(var)
except AssertionError as exc:
if var not in allowed_failures:
print(f"{var}:{indent(str(exc), ' ' * 4)}")
passed = False
assert passed
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment