Commit ef746cfa authored by mashun1's avatar mashun1
Browse files

veros

parents
Pipeline #1302 canceled with stages
import os
from veros import logger, runtime_settings, runtime_state
from veros.io_tools import hdf5 as h5tools
from veros.signals import do_not_disturb
from veros.distributed import get_chunk_slices, exchange_overlap
from veros.variables import get_shape
def read_from_h5(dimensions, var_meta, infile, groupname, enable_cyclic_x):
from veros.core.operators import numpy as npx, update, at
variables = {}
for key, var in infile[groupname].items():
if not var_meta[key].dims:
variables[key] = npx.array(var)
continue
local_shape = get_shape(dimensions, var_meta[key].dims, local=True, include_ghosts=True)
gidx, lidx = get_chunk_slices(dimensions["xt"], dimensions["yt"], var_meta[key].dims, include_overlap=True)
# pass dtype as str to prevent endianness from leaking into array
variables[key] = npx.empty(local_shape, dtype=str(var.dtype))
variables[key] = update(variables[key], at[lidx], var[gidx])
variables[key] = exchange_overlap(variables[key], var_meta[key].dims, enable_cyclic_x)
attributes = {key: var.item() for key, var in infile[groupname].attrs.items()}
return attributes, variables
def write_to_h5(dimensions, var_meta, var_data, outfile, groupname, attributes=None):
if attributes is None:
attributes = {}
group = outfile.require_group(groupname)
for key, var in var_data.items():
var_dims = var_meta[key].dims
if var_dims is None:
var_dims = []
global_shape = get_shape(dimensions, var_dims, local=False)
gidx, lidx = get_chunk_slices(dimensions["xt"], dimensions["yt"], var_dims, include_overlap=True)
kwargs = dict(
exact=True,
)
if var_dims:
chunksize = []
for d in var_dims:
if d in dimensions:
chunksize.append(get_shape(dimensions, (d,), local=True, include_ghosts=False)[0])
else:
chunksize.append(1)
kwargs.update(chunks=tuple(chunksize))
if runtime_settings.hdf5_gzip_compression and runtime_state.proc_num == 1:
kwargs.update(compression="gzip", compression_opts=1)
group.require_dataset(key, global_shape, var.dtype, **kwargs)
group[key][gidx] = var[lidx]
for key, val in attributes.items():
group.attrs[key] = val
def read_restart(state):
settings = state.settings
if not settings.restart_input_filename:
return
if runtime_settings.force_overwrite:
raise RuntimeError("To prevent data loss, force_overwrite cannot be used in restart runs")
statedict = dict(state.variables.items())
statedict.update(state.settings.items())
restart_filename = settings.restart_input_filename.format(**statedict)
if not os.path.isfile(restart_filename):
raise IOError(f"restart file {restart_filename} not found")
logger.info(f"Reading restart data from {restart_filename}")
with h5tools.threaded_io(restart_filename, "r") as infile, state.variables.unlock():
# core restart
restart_vars = {var: meta for var, meta in state.var_meta.items() if meta.write_to_restart and meta.active}
_, restart_data = read_from_h5(state.dimensions, restart_vars, infile, "core", settings.enable_cyclic_x)
for key in restart_vars.keys():
try:
var_data = restart_data[key]
except KeyError:
raise RuntimeError(f"No restart data found for variable {key} in {restart_filename}") from None
setattr(state.variables, key, var_data)
# diagnostic restarts
for diag_name, diagnostic in state.diagnostics.items():
if not diagnostic.var_meta:
# nothing to do
continue
dimensions = dict(state.dimensions)
if diagnostic.extra_dimensions:
dimensions.update(diagnostic.extra_dimensions)
restart_vars = {
var: meta for var, meta in diagnostic.var_meta.items() if meta.write_to_restart and meta.active
}
_, restart_data = read_from_h5(dimensions, restart_vars, infile, diag_name, settings.enable_cyclic_x)
for key in restart_vars.keys():
try:
var_data = restart_data[key]
except KeyError:
raise RuntimeError(
f'No restart data found for variable {key} in {restart_filename} (from diagnostic "{diag_name}")'
) from None
setattr(diagnostic.variables, key, var_data)
return state
@do_not_disturb
def write_restart(state, force=False):
vs = state.variables
settings = state.settings
if runtime_settings.diskless_mode:
return
if not settings.restart_output_filename:
return
write_now = force or (
settings.restart_frequency and vs.itt > 0 and vs.time % settings.restart_frequency < settings.dt_tracer
)
if not write_now:
return
statedict = dict(state.variables.items())
statedict.update(state.settings.items())
restart_filename = settings.restart_output_filename.format(**statedict)
logger.info(f"Writing restart file {restart_filename}")
with h5tools.threaded_io(restart_filename, "w") as outfile:
# core restart
vs = state.variables
restart_vars = {var: meta for var, meta in state.var_meta.items() if meta.write_to_restart and meta.active}
restart_data = {var: getattr(vs, var) for var in restart_vars}
write_to_h5(state.dimensions, restart_vars, restart_data, outfile, "core")
# diagnostic restarts
for diag_name, diagnostic in state.diagnostics.items():
if not diagnostic.var_meta:
# nothing to do
continue
dimensions = dict(state.dimensions)
if diagnostic.extra_dimensions:
dimensions.update(diagnostic.extra_dimensions)
restart_vars = {
var: meta for var, meta in diagnostic.var_meta.items() if meta.write_to_restart and meta.active
}
restart_data = {var: getattr(diagnostic.variables, var) for var in restart_vars}
write_to_h5(dimensions, restart_vars, restart_data, outfile, diag_name)
import functools
import inspect
import threading
from contextlib import ExitStack, contextmanager
from veros import logger
from veros.state import VerosState
# stack helpers
class RoutineStack:
def __init__(self):
self.keep_full_stack = False
self._stack = []
self._current_idx = []
@property
def stack_level(self):
return len(self._current_idx)
def append(self, val):
frame = self._stack
for i in self._current_idx:
frame = frame[i][1]
self._current_idx.append(len(frame))
frame.append([val, []])
def pop(self):
frame = self._stack
for i in self._current_idx[:-1]:
frame = frame[i][1]
if self.keep_full_stack:
last_val = frame[-1][0]
else:
last_val = frame.pop()[0]
self._current_idx.pop()
return last_val
# global context
CURRENT_CONTEXT = threading.local()
CURRENT_CONTEXT.is_dist_safe = True
CURRENT_CONTEXT.routine_stack = RoutineStack()
CURRENT_CONTEXT.mpi4jax_token = None
@contextmanager
def nullcontext():
yield
@contextmanager
def enter_routine(name, routine_obj, timer=None, dist_safe=True):
from veros import runtime_state as rst
from veros.distributed import abort
stack = CURRENT_CONTEXT.routine_stack
logger.trace("{}> {}", "-" * stack.stack_level, name)
stack.append(routine_obj)
reset_dist_safe = False
if CURRENT_CONTEXT.is_dist_safe:
if not dist_safe and rst.proc_num > 1:
CURRENT_CONTEXT.is_dist_safe = False
reset_dist_safe = True
timer_ctx = nullcontext() if timer is None else timer
try:
with timer_ctx:
yield
except: # noqa: E722
if reset_dist_safe:
abort()
raise
finally:
if reset_dist_safe:
CURRENT_CONTEXT.is_dist_safe = True
r = stack.pop()
assert r is routine_obj
exec_time = ""
if timer is not None:
exec_time = f"({timer.last_time:.3f}s)"
logger.trace("<{} {} {}", "-" * stack.stack_level, name, exec_time)
# helper functions
def _get_func_name(function):
return f"{inspect.getmodule(function).__name__}:{function.__qualname__}"
def _is_method(function):
if inspect.ismethod(function):
return True
# hack for unbound methods: check if first argument is called "self"
spec = inspect.getfullargspec(function)
return spec.args and spec.args[0] == "self"
# routine
def veros_routine(function=None, *, dist_safe=True, local_variables=()):
"""
.. note::
This decorator should be applied to all functions that access the Veros state object
(even when subclassing :class:`veros.VerosSetup`).
The first argument to the decorated function must be a VerosState instance.
Veros routines cannot return anything. All changes must be applied to the passed state object.
Parameters:
dist_safe (bool): If set to False, all variables specified in local_variables are synced
to the root process before execution and synced back after. This means that the routine
will only be executed on rank 0. Has no effect in non-distributed contexts.
local_variables (Tuple[str]): List of variable names to be synced if dist_safe=False. This
must include all variables retrieved from the state object throughout the routine (inputs
*and* outputs).
Example:
>>> from veros import VerosSetup, veros_routine
>>>
>>> class MyModel(VerosSetup):
>>> @veros_routine
>>> def set_topography(self, state):
>>> vs = state.variables
>>> settings = state.settings
>>> vs.kbot = npx.random.randint(0, settings.nz, size=vs.kbot.shape)
"""
def inner_decorator(function):
narg = 1 if _is_method(function) else 0
num_params = len(inspect.signature(function).parameters)
if narg >= num_params:
raise TypeError("Veros routines must take at least one argument")
routine = VerosRoutine(function, state_argnum=narg, dist_safe=dist_safe, local_variables=local_variables)
routine = functools.wraps(function)(routine)
return routine
if function is not None:
return inner_decorator(function)
return inner_decorator
class VerosRoutine:
"""Do not instantiate directly!"""
def __init__(self, function, dist_safe=True, local_variables=(), state_argnum=0):
if isinstance(local_variables, str):
local_variables = (local_variables,)
self.function = function
self.dist_safe = dist_safe
self.local_variables = local_variables
self.state_argnum = state_argnum
self.name = _get_func_name(self.function)
def __call__(self, *args, **kwargs):
from veros import runtime_state as rst
from veros.state import VerosState, DistSafeVariableWrapper
from veros.core.operators import flush
veros_state = args[self.state_argnum]
if not isinstance(veros_state, VerosState):
raise TypeError(f"Argument {self.state_argnum} to this Veros routine must be a VerosState object")
timer = veros_state.profile_timers[self.name]
with ExitStack() as es:
vars_initialized = veros_state._variables is not None
if vars_initialized:
es.enter_context(veros_state.variables.unlock())
execute = True
restore_vars = False
if not self.dist_safe:
orig_vars = veros_state._variables
if not isinstance(orig_vars, DistSafeVariableWrapper):
restore_vars = True
veros_state._variables = DistSafeVariableWrapper(orig_vars, self.local_variables)
veros_state._variables._gather_variables()
execute = rst.proc_rank == 0
routine_ctx = enter_routine(name=self.name, routine_obj=self, timer=timer, dist_safe=self.dist_safe)
out = None
try:
with routine_ctx:
if execute:
out = self.function(*args, **kwargs)
finally:
if restore_vars:
veros_state._variables._scatter_variables()
veros_state._variables = orig_vars
flush()
if out is not None:
logger.warning(
f"Routine {self.name} returned object of type {type(out)}. Return objects are silently dropped."
)
def __get__(self, instance, _):
return functools.partial(self.__call__, instance)
def __repr__(self):
return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"
# kernel
def veros_kernel(function=None, *, static_args=()):
"""Decorator that marks a function as a kernel that can be JIT compiled if supported
by the backend.
Kernels cannot modify the Veros state object. Instead, all modifications have to be
returned explicitly.
Parameters:
static_args (Tuple[str]): Names of kernel arguments that should be static.
Example:
>>> from veros import veros_kernel, KernelOutput
>>>
>>> @veros_kernel
>>> def double_psi(state):
>>> vs = state.variables
>>> vs.psi = 2 * vs.psi
>>> return KernelOutput(psi=vs.psi)
"""
def inner_decorator(function):
kernel = VerosKernel(function, static_args=static_args)
kernel = functools.wraps(function)(kernel)
return kernel
if function is not None:
return inner_decorator(function)
return inner_decorator
class VerosKernel:
"""Do not instantiate directly!"""
def __init__(self, function, static_args=()):
"""Do some parameter introspection."""
# make sure function signature is in the form we need
self.name = _get_func_name(function)
self.func_sig = inspect.signature(function)
func_params = self.func_sig.parameters
allowed_param_types = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
if any(p.kind not in allowed_param_types for p in func_params.values()):
raise ValueError(f"Veros kernels do not support *args, **kwargs, or keyword-only parameters ({self.name})")
# parse static args
if isinstance(static_args, str):
static_args = (static_args,)
func_argnames = list(func_params.keys())
self.static_argnums = []
for static_arg in static_args:
try:
arg_index = func_argnames.index(static_arg)
except ValueError:
raise ValueError(
f'Veros kernel {self.name} has no argument "{static_arg}", but it is given in static_args'
) from None
self.static_argnums.append(arg_index)
self.function = function
def __call__(self, *args, **kwargs):
from veros import runtime_settings, runtime_state
from veros.core.operators import flush
inject_tokens = runtime_settings.backend == "jax" and runtime_state.proc_num > 1
# apply JIT
if runtime_settings.backend == "jax":
import jax
CompiledFunction = type(jax.jit(lambda: None))
if not isinstance(self.function, CompiledFunction):
if inject_tokens:
function = self.function
@functools.wraps(function)
def token_wrapper(*args):
inputs = args[:-1]
token = args[-1]
CURRENT_CONTEXT.mpi4jax_token = token
out = function(*inputs)
token = CURRENT_CONTEXT.mpi4jax_token
return out, token
if CURRENT_CONTEXT.mpi4jax_token is None:
CURRENT_CONTEXT.mpi4jax_token = jax.lax.create_token()
self.function = token_wrapper
self.function = jax.jit(self.function, static_argnums=self.static_argnums)
# JAX only accepts positional args when using static_argnums
# so convert everything to positional for consistency
bound_args = self.func_sig.bind(*args, **kwargs)
bound_args.apply_defaults()
veros_state = None
for argval in bound_args.arguments.values():
if isinstance(argval, VerosState):
veros_state = argval
break
called_with_state = veros_state is not None
# when profiling, make sure all inputs are ready before starting the timer
if runtime_settings.profile_mode:
flush()
if called_with_state:
timer = veros_state.profile_timers[self.name]
else:
timer = None
with ExitStack() as es:
if called_with_state:
es.enter_context(veros_state.variables.unlock())
args = list(bound_args.arguments.values())
if inject_tokens:
args.append(CURRENT_CONTEXT.mpi4jax_token)
with enter_routine(self.name, self, timer):
out = self.function(*args)
if runtime_settings.profile_mode:
flush()
if inject_tokens:
out, token = out
CURRENT_CONTEXT.mpi4jax_token = token
return out
def __repr__(self):
return f"<{self.__class__.__name__} {self.name} at {hex(id(self))}>"
def is_veros_routine(func):
if isinstance(func, functools.partial):
func = func.func
if inspect.ismethod(func):
func = func.__self__
return isinstance(func, VerosRoutine)
import os
from threading import local
from collections import namedtuple
from veros.backend import BACKENDS
from veros.logs import LOGLEVELS
# globals
log_args = local()
log_args.log_all_processes = False
log_args.loglevel = "info"
# MPI helpers
def _default_mpi_comm():
try:
from mpi4py import MPI
except ImportError:
return None
else:
return MPI.COMM_WORLD
# validators
def parse_two_ints(v):
return (int(v[0]), int(v[1]))
def parse_choice(choices, preserve_case=False):
def validate(choice):
if isinstance(choice, str) and not preserve_case:
choice = choice.lower()
if choice not in choices:
raise ValueError(f"must be one of {choices}")
return choice
return validate
def parse_bool(obj):
if not isinstance(obj, str):
return bool(obj)
return obj.lower() in {"1", "true", "on"}
def check_mpi_comm(comm):
if comm is not None:
from mpi4py import MPI
if not isinstance(comm, MPI.Comm):
raise TypeError("mpi_comm must be Comm instance or None")
return comm
def set_loglevel(val):
from veros import logs
log_args.loglevel = parse_choice(LOGLEVELS)(val)
logs.setup_logging(loglevel=log_args.loglevel, log_all_processes=log_args.log_all_processes)
return log_args.loglevel
def set_log_all_processes(val):
from veros import logs
log_args.log_all_processes = parse_bool(val)
logs.setup_logging(loglevel=log_args.loglevel, log_all_processes=log_args.log_all_processes)
return log_args.log_all_processes
DEVICES = ("cpu", "gpu", "tpu")
FLOAT_TYPES = ("float64", "float32")
LINEAR_SOLVERS = ("scipy", "scipy_jax", "petsc", "best")
# settings
RuntimeSetting = namedtuple("RuntimeSetting", ("type", "default", "read_from_env"))
RuntimeSetting.__new__.__defaults__ = (None, None, True)
AVAILABLE_SETTINGS = {
"backend": RuntimeSetting(parse_choice(BACKENDS), "numpy"),
"device": RuntimeSetting(parse_choice(DEVICES), "cpu"),
"float_type": RuntimeSetting(parse_choice(FLOAT_TYPES), "float64"),
"linear_solver": RuntimeSetting(parse_choice(LINEAR_SOLVERS), "best"),
"petsc_options": RuntimeSetting(str, ""),
"monitor_streamfunction_residual": RuntimeSetting(parse_bool, True),
"num_proc": RuntimeSetting(parse_two_ints, (1, 1), read_from_env=False),
"profile_mode": RuntimeSetting(parse_bool, False),
"loglevel": RuntimeSetting(set_loglevel, "info"),
"mpi_comm": RuntimeSetting(check_mpi_comm, _default_mpi_comm(), read_from_env=False),
"log_all_processes": RuntimeSetting(set_log_all_processes, False),
"use_io_threads": RuntimeSetting(parse_bool, False),
"io_timeout": RuntimeSetting(float, 20),
"hdf5_gzip_compression": RuntimeSetting(parse_bool, True),
"force_overwrite": RuntimeSetting(parse_bool, False),
"diskless_mode": RuntimeSetting(parse_bool, False),
"pyom_compatibility_mode": RuntimeSetting(parse_bool, False),
"setup_file": RuntimeSetting(str, None, read_from_env=False),
"use_special_tdma": RuntimeSetting(parse_bool, None),
}
class RuntimeSettings:
__slots__ = ["__locked__", "__setting_types__", "__settings__", *AVAILABLE_SETTINGS.keys()]
def __init__(self, **kwargs):
self.__locked__ = False
self.__setting_types__ = {}
for name, setting in AVAILABLE_SETTINGS.items():
setting_envvar = f"VEROS_{name.upper()}"
if name in kwargs:
val = kwargs[name]
elif setting.read_from_env:
val = os.environ.get(setting_envvar, setting.default)
else:
val = setting.default
self.__setting_types__[name] = setting.type
self.__setattr__(name, val)
self.__settings__ = set(self.__setting_types__.keys())
def update(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
return self
def __setattr__(self, attr, val):
if getattr(self, "__locked__", False):
raise RuntimeError("Runtime settings cannot be modified after import of core modules")
if attr.startswith("_"):
return super().__setattr__(attr, val)
# coerce type
stype = self.__setting_types__.get(attr)
if stype is not None:
try:
val = stype(val)
except (TypeError, ValueError) as e:
raise ValueError(f'Got invalid value for runtime setting "{attr}": {e!s}') from None
return super().__setattr__(attr, val)
def __repr__(self):
setval = ", ".join(f"{key}={repr(getattr(self, key))}" for key in self.__settings__)
return f"{self.__class__.__name__}({setval})"
# state
class RuntimeState:
"""Unifies attributes from various modules in a simple read-only object"""
__slots__ = ()
@property
def proc_rank(self):
from veros import runtime_settings
comm = runtime_settings.mpi_comm
if comm is None:
return 0
return comm.Get_rank()
@property
def proc_num(self):
from veros import runtime_settings
comm = runtime_settings.mpi_comm
if comm is None:
return 1
return comm.Get_size()
@property
def proc_idx(self):
from veros import distributed
return distributed.proc_rank_to_index(self.proc_rank)
@property
def backend_module(self):
from veros import backend, runtime_settings
return backend.get_backend_module(runtime_settings.backend)
def __setattr__(self, attr, val):
raise TypeError(f"Cannot modify {self.__class__.__name__} objects")
from collections import namedtuple
Setting = namedtuple("setting", ("default", "type", "description"))
def optional(type_):
def wrapped(arg):
if arg is None:
return arg
return type_(arg)
return wrapped
PI = 3.14159265358979323846264338327950588
SETTINGS = {
"identifier": Setting("UNNAMED", str, "Identifier of the current simulation"),
"description": Setting("", str, "Description of the current simulation"),
# Model parameters
"nx": Setting(0, int, "Grid points in zonal (x) direction"),
"ny": Setting(0, int, "Grid points in meridional (y,j) direction"),
"nz": Setting(0, int, "Grid points in vertical (z,k) direction"),
"dt_mom": Setting(0.0, float, "Time step in seconds for momentum"),
"dt_tracer": Setting(0.0, float, "Time step for tracers, can be larger than dt_mom"),
"runlen": Setting(0.0, float, "Length of simulation in seconds"),
"AB_eps": Setting(0.1, float, "Deviation from Adam-Bashforth weighting"),
"x_origin": Setting(0, float, "Grid origin in x-direction"),
"y_origin": Setting(0, float, "Grid origin in y-direction"),
# Physical constants
"pi": Setting(PI, float, "Pi"),
"radius": Setting(6370e3, float, "Earth radius in m"),
"degtom": Setting(6370e3 / 180.0 * PI, float, "Conversion factor from degrees latitude to meters"),
"omega": Setting(PI / 43082.0, float, "Earth rotation frequency in 1/s"),
"rho_0": Setting(1024.0, float, "Boussinesq reference density in :math:`kg/m^3`"),
"grav": Setting(9.81, float, "Gravitational constant in :math:`m/s^2`"),
# Logical switches for general model setup
"coord_degree": Setting(False, bool, "either spherical (True) or cartesian (False) coordinates"),
"enable_cyclic_x": Setting(False, bool, "enable cyclic boundary conditions"),
"eq_of_state_type": Setting(1, int, "equation of state: 1: linear, 3: nonlinear with comp., 5: TEOS"),
"enable_implicit_vert_friction": Setting(False, bool, "enable implicit vertical friction"),
"enable_explicit_vert_friction": Setting(False, bool, "enable explicit vertical friction"),
"enable_hor_friction": Setting(False, bool, "enable horizontal friction"),
"enable_hor_diffusion": Setting(False, bool, "enable horizontal diffusion"),
"enable_biharmonic_friction": Setting(False, bool, "enable biharmonic horizontal friction"),
"enable_biharmonic_mixing": Setting(False, bool, "enable biharmonic horizontal mixing"),
"enable_hor_friction_cos_scaling": Setting(False, bool, "scaling of hor. viscosity with cos(latitude)**cosPower"),
"enable_ray_friction": Setting(False, bool, "enable Rayleigh damping"),
"enable_bottom_friction": Setting(False, bool, "enable bottom friction"),
"enable_bottom_friction_var": Setting(False, bool, "enable bottom friction with lateral variations"),
"enable_quadratic_bottom_friction": Setting(False, bool, "enable quadratic bottom friction"),
"enable_tempsalt_sources": Setting(False, bool, "enable restoring zones, etc"),
"enable_momentum_sources": Setting(False, bool, "enable restoring zones, etc"),
"enable_superbee_advection": Setting(False, bool, "enable advection scheme with implicit mixing"),
"enable_conserve_energy": Setting(True, bool, "exchange energy consistently"),
"enable_store_bottom_friction_tke": Setting(
False, bool, "transfer dissipated energy by bottom/rayleig fric. to TKE, else transfer to internal waves"
),
"enable_store_cabbeling_heat": Setting(
False, bool, "transfer non-linear mixing terms to potential enthalpy, else transfer to TKE and EKE"
),
"enable_noslip_lateral": Setting(
False, bool, "enable lateral no-slip boundary conditions in harmonic- and biharmonic friction."
),
"enable_streamfunction": Setting(
True,
bool,
"solve for external mode with barotropic streamfunction, else solve for surface pressure and sea surface height",
),
# Mixing parameters
"A_h": Setting(0.0, float, "lateral viscosity in m^2/s"),
"K_h": Setting(0.0, float, "lateral diffusivity in m^2/s"),
"r_ray": Setting(0.0, float, "Rayleigh damping coefficient in 1/s"),
"r_bot": Setting(0.0, float, "bottom friction coefficient in 1/s"),
"r_quad_bot": Setting(0.0, float, "qudratic bottom friction coefficient"),
"hor_friction_cosPower": Setting(3, float, "power to scale cos term by in horizontal friction"),
"A_hbi": Setting(0.0, float, "lateral biharmonic viscosity in m^4/s"),
"K_hbi": Setting(0.0, float, "lateral biharmonic diffusivity in m^4/s"),
"biharmonic_friction_cosPower": Setting(0, float, "power to scale cos term by in biharmonic friction"),
"kappaH_0": Setting(0.0, float, "fixed values for vertical viscosity/diffusivity which are set for no TKE model"),
"kappaM_0": Setting(0.0, float, "fixed values for vertical viscosity/diffusivity which are set for no TKE model"),
# Options for isopycnal mixing
"enable_neutral_diffusion": Setting(False, bool, "enable isopycnal mixing"),
"enable_skew_diffusion": Setting(False, bool, "enable skew diffusion approach for eddy-driven velocities"),
"enable_TEM_friction": Setting(False, bool, "TEM approach for eddy-driven velocities"),
"K_iso_0": Setting(0.0, float, "constant for isopycnal diffusivity in m^2/s"),
"K_iso_steep": Setting(0.0, float, "lateral diffusivity for steep slopes in m^2/s"),
"K_gm_0": Setting(0.0, float, "fixed value for K_gm which is set for no EKE model"),
"iso_dslope": Setting(0.0008, float, "parameters controlling max allowed isopycnal slopes"),
"iso_slopec": Setting(0.001, float, "parameters controlling max allowed isopycnal slopes"),
# Idemix 1.0
"enable_idemix": Setting(False, bool, ""),
"tau_v": Setting(2.0 * 86400.0, float, "time scale for vertical symmetrisation"),
"tau_h": Setting(15.0 * 86400.0, float, "time scale for horizontal symmetrisation"),
"gamma": Setting(1.57, float, ""),
"jstar": Setting(5.0, float, "spectral bandwidth in modes"),
"mu0": Setting(1.0 / 3.0, float, "dissipation parameter"),
"enable_idemix_hor_diffusion": Setting(False, bool, ""),
"enable_eke_diss_bottom": Setting(False, bool, ""),
"enable_eke_diss_surfbot": Setting(False, bool, ""),
"eke_diss_surfbot_frac": Setting(1.0, float, "fraction which goes into bottom"),
"enable_idemix_superbee_advection": Setting(False, bool, ""),
"enable_idemix_upwind_advection": Setting(False, bool, ""),
# TKE
"enable_tke": Setting(False, bool, ""),
"c_k": Setting(0.1, float, ""),
"c_eps": Setting(0.7, float, ""),
"alpha_tke": Setting(1.0, float, ""),
"mxl_min": Setting(1e-12, float, ""),
"kappaM_min": Setting(0.0, float, ""),
"kappaM_max": Setting(100.0, float, ""),
"tke_mxl_choice": Setting(1, int, ""),
"enable_tke_superbee_advection": Setting(False, bool, ""),
"enable_tke_upwind_advection": Setting(False, bool, ""),
"enable_tke_hor_diffusion": Setting(False, bool, ""),
"K_h_tke": Setting(2000.0, float, "lateral diffusivity for tke"),
# EKE
"enable_eke": Setting(False, bool, ""),
"eke_lmin": Setting(100.0, float, "minimal length scale in m"),
"eke_c_k": Setting(1.0, float, ""),
"eke_cross": Setting(1.0, float, "Parameter for EKE model"),
"eke_crhin": Setting(1.0, float, "Parameter for EKE model"),
"eke_c_eps": Setting(1.0, float, "Parameter for EKE model"),
"eke_k_max": Setting(1e4, float, "maximum of K_gm"),
"alpha_eke": Setting(1.0, float, "factor vertical friction"),
"enable_eke_superbee_advection": Setting(False, bool, ""),
"enable_eke_upwind_advection": Setting(False, bool, ""),
"enable_eke_isopycnal_diffusion": Setting(False, bool, "use K_gm also for isopycnal diffusivity"),
# Restarts
"restart_input_filename": Setting(
None, optional(str), "File name of restart input. If not given, no restart data will be read."
),
"restart_output_filename": Setting(
"{identifier}_{itt:0>4d}.restart.h5",
optional(str),
"File name of restart output. May contain Python format syntax that is substituted with Veros attributes.",
),
"restart_frequency": Setting(0, float, "Frequency (in seconds) to write restart data"),
# New
"kappaH_min": Setting(0.0, float, "minimum value for vertical diffusivity"),
"enable_kappaH_profile": Setting(
False, bool, "Compute vertical profile of diffusivity after Bryan and Lewis (1979) in TKE routine"
),
"enable_Prandtl_tke": Setting(True, bool, "Compute Prandtl number from stratification levels in TKE routine"),
"Prandtl_tke0": Setting(
10.0, float, "Constant Prandtl number when stratification is neglected for kappaH computation in TKE routine"
),
}
def check_setting_conflicts(settings):
if settings.enable_tke and not settings.enable_implicit_vert_friction:
raise RuntimeError(
"use TKE model only with implicit vertical friction (set enable_implicit_vert_fricton to True)"
)
from veros.setups.acc.acc import ACCSetup # noqa: F401
#!/usr/bin/env python
from veros import VerosSetup, veros_routine
from veros.variables import allocate, Variable
from veros.distributed import global_min, global_max
from veros.core.operators import numpy as npx, update, at
class ACCSetup(VerosSetup):
"""A model using spherical coordinates with a partially closed domain representing the Atlantic and ACC.
Wind forcing over the channel part and buoyancy relaxation drive a large-scale meridional overturning circulation.
This setup demonstrates:
- setting up an idealized geometry
- updating surface forcings
- basic usage of diagnostics
`Adapted from pyOM2 <https://wiki.cen.uni-hamburg.de/ifm/TO/pyOM2/ACC%202>`_.
"""
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "acc"
settings.description = "My ACC setup"
settings.nx, settings.ny, settings.nz = 30, 42, 15
settings.dt_mom = 4800
settings.dt_tracer = 86400 / 2.0
settings.runlen = 86400 * 365
settings.x_origin = 0.0
settings.y_origin = -40.0
settings.coord_degree = True
settings.enable_cyclic_x = True
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 500.0
settings.iso_dslope = 0.005
settings.iso_slopec = 0.01
settings.enable_skew_diffusion = True
settings.enable_hor_friction = True
settings.A_h = (2 * settings.degtom) ** 3 * 2e-11
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_bottom_friction = True
settings.r_bot = 1e-5
settings.enable_implicit_vert_friction = True
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_kappaH_profile = True
settings.K_gm_0 = 1000.0
settings.enable_eke = True
settings.eke_k_max = 1e4
settings.eke_c_k = 0.4
settings.eke_c_eps = 0.5
settings.eke_cross = 2.0
settings.eke_crhin = 1.0
settings.eke_lmin = 100.0
settings.enable_eke_superbee_advection = True
settings.enable_eke_isopycnal_diffusion = True
settings.enable_idemix = False
settings.eq_of_state_type = 3
var_meta = state.var_meta
var_meta.update(
t_star=Variable("t_star", ("yt",), "deg C", "Reference surface temperature"),
t_rest=Variable("t_rest", ("xt", "yt"), "1/s", "Surface temperature restoring time scale"),
)
@veros_routine
def set_grid(self, state):
vs = state.variables
ddz = npx.array(
[50.0, 70.0, 100.0, 140.0, 190.0, 240.0, 290.0, 340.0, 390.0, 440.0, 490.0, 540.0, 590.0, 640.0, 690.0]
)
vs.dxt = update(vs.dxt, at[...], 2.0)
vs.dyt = update(vs.dyt, at[...], 2.0)
vs.dzt = update(vs.dzt, at[...], ddz[::-1] / 2.5)
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[None, :] / 180.0 * settings.pi)
)
@veros_routine
def set_topography(self, state):
vs = state.variables
x, y = npx.meshgrid(vs.xt, vs.yt, indexing="ij")
vs.kbot = npx.logical_or(x > 1.0, y < -20).astype("int")
@veros_routine
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
# initial conditions
vs.temp = update(vs.temp, at[...], ((1 - vs.zt[None, None, :] / vs.zw[0]) * 15 * vs.maskT)[..., None])
vs.salt = update(vs.salt, at[...], 35.0 * vs.maskT[..., None])
# wind stress forcing
yt_min = global_min(vs.yt.min())
yu_min = global_min(vs.yu.min())
yt_max = global_max(vs.yt.max())
yu_max = global_max(vs.yu.max())
taux = allocate(state.dimensions, ("yt",))
taux = npx.where(vs.yt < -20, 0.1 * npx.sin(settings.pi * (vs.yu - yu_min) / (-20.0 - yt_min)), taux)
taux = npx.where(vs.yt > 10, 0.1 * (1 - npx.cos(2 * settings.pi * (vs.yu - 10.0) / (yu_max - 10.0))), taux)
vs.surface_taux = taux * vs.maskU[:, :, -1]
# surface heatflux forcing
vs.t_star = allocate(state.dimensions, ("yt",), fill=15)
vs.t_star = npx.where(vs.yt < -20, 15 * (vs.yt - yt_min) / (-20 - yt_min), vs.t_star)
vs.t_star = npx.where(vs.yt > 20, 15 * (1 - (vs.yt - 20) / (yt_max - 20)), vs.t_star)
vs.t_rest = vs.dzt[npx.newaxis, -1] / (30.0 * 86400.0) * vs.maskT[:, :, -1]
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[2:-2, 2:-2],
npx.sqrt(
(0.5 * (vs.surface_taux[2:-2, 2:-2] + vs.surface_taux[1:-3, 2:-2]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[2:-2, 2:-2] + vs.surface_tauy[2:-2, 1:-3]) / settings.rho_0) ** 2
)
** (1.5),
)
if settings.enable_idemix:
vs.forc_iw_bottom = 1e-6 * vs.maskW[:, :, -1]
vs.forc_iw_surface = 1e-7 * vs.maskW[:, :, -1]
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.forc_temp_surface = vs.t_rest * (vs.t_star - vs.temp[:, :, -1, vs.tau])
@veros_routine
def set_diagnostics(self, state):
settings = state.settings
diagnostics = state.diagnostics
diagnostics["snapshot"].output_frequency = 86400 * 10
diagnostics["averages"].output_variables = (
"salt",
"temp",
"u",
"v",
"w",
"psi",
"surface_taux",
"surface_tauy",
)
diagnostics["averages"].output_frequency = 365 * 86400.0
diagnostics["averages"].sampling_frequency = settings.dt_tracer * 10
diagnostics["overturning"].output_frequency = 365 * 86400.0 / 48.0
diagnostics["overturning"].sampling_frequency = settings.dt_tracer * 10
diagnostics["tracer_monitor"].output_frequency = 365 * 86400.0 / 12.0
diagnostics["energy"].output_frequency = 365 * 86400.0 / 48
diagnostics["energy"].sampling_frequency = settings.dt_tracer * 10
@veros_routine
def after_timestep(self, state):
pass
from veros.setups.acc_basic.acc_basic import ACCBasicSetup # noqa: F401
#!/usr/bin/env python
from veros import VerosSetup, veros_routine
from veros.variables import allocate, Variable
from veros.distributed import global_min, global_max
from veros.core.operators import numpy as npx, update, at
class ACCBasicSetup(VerosSetup):
"""A model using spherical coordinates with a partially closed domain representing the Atlantic and ACC.
Wind forcing over the channel part and buoyancy relaxation drive a large-scale meridional overturning circulation.
This setup demonstrates:
- setting up an idealized geometry
- updating surface forcings
- constant horizontal and vertical mixing (switched off IDEMIX and GM_EKE)
- basic usage of diagnostics
:doc:`Adapted from ACC channel model </reference/setups/acc>`.
"""
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "acc_basic"
settings.description = "My ACC basic setup"
settings.nx, settings.ny, settings.nz = 30, 42, 15
settings.dt_mom = 4800
settings.dt_tracer = 86400 / 2.0
settings.runlen = 86400 * 365 * 20
settings.x_origin = 0.0
settings.y_origin = -40.0
settings.coord_degree = True
settings.enable_cyclic_x = True
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 500.0
settings.iso_dslope = 0.005
settings.iso_slopec = 0.01
settings.enable_skew_diffusion = True
settings.enable_hor_friction = True
settings.A_h = (2 * settings.degtom) ** 3 * 2e-11
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_bottom_friction = True
settings.r_bot = 1e-5
settings.enable_implicit_vert_friction = True
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_Prandtl_tke = False
settings.enable_kappaH_profile = True
settings.K_gm_0 = 1000.0
settings.enable_eke = False
settings.enable_idemix = False
settings.eq_of_state_type = 3
var_meta = state.var_meta
var_meta.update(
t_star=Variable("t_star", ("yt",), "deg C", "Reference surface temperature"),
t_rest=Variable("t_rest", ("xt", "yt"), "1/s", "Surface temperature restoring time scale"),
)
@veros_routine
def set_grid(self, state):
vs = state.variables
ddz = npx.array(
[50.0, 70.0, 100.0, 140.0, 190.0, 240.0, 290.0, 340.0, 390.0, 440.0, 490.0, 540.0, 590.0, 640.0, 690.0]
)
vs.dxt = update(vs.dxt, at[...], 2.0)
vs.dyt = update(vs.dyt, at[...], 2.0)
vs.dzt = ddz[::-1] / 2.5
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[:, :], 2 * settings.omega * npx.sin(vs.yt[None, :] / 180.0 * settings.pi)
)
@veros_routine
def set_topography(self, state):
vs = state.variables
x, y = npx.meshgrid(vs.xt, vs.yt, indexing="ij")
vs.kbot = npx.logical_or(x > 1.0, y < -20).astype("int")
@veros_routine
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
# initial conditions
vs.temp = update(vs.temp, at[...], ((1 - vs.zt[None, None, :] / vs.zw[0]) * 15 * vs.maskT)[..., None])
vs.salt = update(vs.salt, at[...], 35.0 * vs.maskT[..., None])
# wind stress forcing
yt_min = global_min(vs.yt.min())
yu_min = global_min(vs.yu.min())
yt_max = global_max(vs.yt.max())
yu_max = global_max(vs.yu.max())
taux = allocate(state.dimensions, ("yt",))
taux = npx.where(vs.yt < -20, 0.1 * npx.sin(settings.pi * (vs.yu - yu_min) / (-20.0 - yt_min)), taux)
taux = npx.where(vs.yt > 10, 0.1 * (1 - npx.cos(2 * settings.pi * (vs.yu - 10.0) / (yu_max - 10.0))), taux)
vs.surface_taux = taux * vs.maskU[:, :, -1]
# surface heatflux forcing
vs.t_star = allocate(state.dimensions, ("yt",), fill=15)
vs.t_star = npx.where(vs.yt < -20, 15 * (vs.yt - yt_min) / (-20 - yt_min), vs.t_star)
vs.t_star = npx.where(vs.yt > 20, 15 * (1 - (vs.yt - 20) / (yt_max - 20)), vs.t_star)
vs.t_rest = vs.dzt[npx.newaxis, -1] / (30.0 * 86400.0) * vs.maskT[:, :, -1]
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[2:-2, 2:-2],
npx.sqrt(
(0.5 * (vs.surface_taux[2:-2, 2:-2] + vs.surface_taux[1:-3, 2:-2]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[2:-2, 2:-2] + vs.surface_tauy[2:-2, 1:-3]) / settings.rho_0) ** 2
)
** (1.5),
)
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.forc_temp_surface = vs.t_rest * (vs.t_star - vs.temp[:, :, -1, vs.tau])
@veros_routine
def set_diagnostics(self, state):
settings = state.settings
diagnostics = state.diagnostics
diagnostics["averages"].output_variables = (
"salt",
"temp",
"u",
"v",
"w",
"psi",
"surface_taux",
"surface_tauy",
)
diagnostics["averages"].output_frequency = 365 * 86400.0
diagnostics["averages"].sampling_frequency = settings.dt_tracer * 10
diagnostics["overturning"].output_frequency = 365 * 86400.0 / 48.0
diagnostics["overturning"].sampling_frequency = settings.dt_tracer * 10
diagnostics["tracer_monitor"].output_frequency = 365 * 86400.0 / 12.0
@veros_routine
def after_timestep(self, state):
pass
from veros.setups.global_1deg.global_1deg import GlobalOneDegreeSetup # noqa: F401
{
"forcing": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/global_1deg/forcing_1deg_global.nc",
"md5": "1fc86f88acd820da078c8da5873cfa01"
}
}
import os
import h5netcdf
from veros import VerosSetup, tools, time, veros_routine, veros_kernel, KernelOutput
from veros.variables import Variable, allocate
from veros.core.operators import numpy as npx, update, update_multiply, at
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
DATA_FILES = tools.get_assets("global_1deg", os.path.join(BASE_PATH, "assets.json"))
class GlobalOneDegreeSetup(VerosSetup):
"""Global 1 degree model with 115 vertical levels.
`Adapted from pyOM2 <https://wiki.zmaw.de/ifm/TO/pyOM2/1x1%20global%20model>`_.
"""
@veros_routine
def set_parameter(self, state):
"""
set main parameters
"""
settings = state.settings
settings.identifier = "global_1deg"
settings.description = "My global 1 degree setup"
settings.nx = 360
settings.ny = 160
settings.nz = 115
settings.dt_mom = 1800.0
settings.dt_tracer = 1800.0
settings.runlen = 10 * settings.dt_tracer
settings.x_origin = 91.0
settings.y_origin = -79.0
settings.coord_degree = True
settings.enable_cyclic_x = True
settings.enable_hor_friction = True
settings.A_h = 5e4
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_tempsalt_sources = True
settings.enable_implicit_vert_friction = True
settings.eq_of_state_type = 5
# isoneutral
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 50.0
settings.iso_dslope = 0.005
settings.iso_slopec = 0.005
settings.enable_skew_diffusion = True
# tke
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 1
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_kappaH_profile = True
settings.enable_tke_superbee_advection = True
# eke
settings.enable_eke = True
settings.eke_k_max = 1e4
settings.eke_c_k = 0.4
settings.eke_c_eps = 0.5
settings.eke_cross = 2.0
settings.eke_crhin = 1.0
settings.eke_lmin = 100.0
settings.enable_eke_superbee_advection = True
settings.enable_eke_isopycnal_diffusion = True
# idemix
settings.enable_idemix = False
settings.enable_eke_diss_surfbot = True
settings.eke_diss_surfbot_frac = 0.2
settings.enable_idemix_superbee_advection = True
settings.enable_idemix_hor_diffusion = True
# custom variables
state.dimensions["nmonths"] = 12
state.var_meta.update(
t_star=Variable("t_star", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
s_star=Variable("s_star", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnec=Variable("qnec", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnet=Variable("qnet", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qsol=Variable("qsol", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
divpen_shortwave=Variable("divpen_shortwave", ("zt",), "", "", time_dependent=False),
taux=Variable("taux", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
tauy=Variable("tauy", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
)
def _read_forcing(self, var):
from veros.core.operators import numpy as npx
with h5netcdf.File(DATA_FILES["forcing"], "r") as infile:
var = infile.variables[var]
return npx.asarray(var).T
@veros_routine
def set_grid(self, state):
vs = state.variables
dz_data = self._read_forcing("dz")
vs.dzt = update(vs.dzt, at[...], dz_data[::-1])
vs.dxt = update(vs.dxt, at[...], 1.0)
vs.dyt = update(vs.dyt, at[...], 1.0)
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[npx.newaxis, :] / 180.0 * settings.pi)
)
@veros_routine(dist_safe=False, local_variables=["kbot"])
def set_topography(self, state):
import numpy as onp
vs = state.variables
settings = state.settings
bathymetry_data = self._read_forcing("bathymetry")
salt_data = self._read_forcing("salinity")[:, :, ::-1]
mask_salt = salt_data == 0.0
vs.kbot = update(vs.kbot, at[2:-2, 2:-2], 1 + npx.sum(mask_salt.astype("int"), axis=2))
mask_bathy = bathymetry_data == 0
vs.kbot = update_multiply(vs.kbot, at[2:-2, 2:-2], ~mask_bathy)
vs.kbot = vs.kbot * (vs.kbot < settings.nz)
# close some channels
i, j = onp.indices((settings.nx, settings.ny))
mask_channel = (i >= 207) & (i < 214) & (j < 5) # i = 208,214; j = 1,5
vs.kbot = update_multiply(vs.kbot, at[2:-2, 2:-2], ~mask_channel)
# Aleutian islands
mask_channel = (i == 104) & (j == 134) # i = 105; j = 135
vs.kbot = update_multiply(vs.kbot, at[2:-2, 2:-2], ~mask_channel)
# Engl channel
mask_channel = (i >= 269) & (i < 271) & (j == 130) # i = 270,271; j = 131
vs.kbot = update_multiply(vs.kbot, at[2:-2, 2:-2], ~mask_channel)
@veros_routine(
dist_safe=False,
local_variables=[
"t_star",
"s_star",
"qnec",
"qnet",
"qsol",
"divpen_shortwave",
"taux",
"tauy",
"temp",
"salt",
"forc_iw_bottom",
"forc_iw_surface",
"kbot",
"maskT",
"maskW",
"zw",
"dzt",
],
)
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
rpart_shortwave = 0.58
efold1_shortwave = 0.35
efold2_shortwave = 23.0
# initial conditions
temp_data = self._read_forcing("temperature")
vs.temp = update(vs.temp, at[2:-2, 2:-2, :, 0], temp_data[..., ::-1] * vs.maskT[2:-2, 2:-2, :])
vs.temp = update(vs.temp, at[2:-2, 2:-2, :, 1], temp_data[..., ::-1] * vs.maskT[2:-2, 2:-2, :])
salt_data = self._read_forcing("salinity")
vs.salt = update(vs.salt, at[2:-2, 2:-2, :, 0], salt_data[..., ::-1] * vs.maskT[2:-2, 2:-2, :])
vs.salt = update(vs.salt, at[2:-2, 2:-2, :, 1], salt_data[..., ::-1] * vs.maskT[2:-2, 2:-2, :])
# wind stress on MIT grid
vs.taux = update(vs.taux, at[2:-2, 2:-2, :], self._read_forcing("tau_x"))
vs.tauy = update(vs.tauy, at[2:-2, 2:-2, :], self._read_forcing("tau_y"))
qnet_data = self._read_forcing("q_net")
vs.qnet = update(vs.qnet, at[2:-2, 2:-2, :], -qnet_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
qnec_data = self._read_forcing("dqdt")
vs.qnec = update(vs.qnec, at[2:-2, 2:-2, :], qnec_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
qsol_data = self._read_forcing("swf")
vs.qsol = update(vs.qsol, at[2:-2, 2:-2, :], -qsol_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
# SST and SSS
sst_data = self._read_forcing("sst")
vs.t_star = update(vs.t_star, at[2:-2, 2:-2, :], sst_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
sss_data = self._read_forcing("sss")
vs.s_star = update(vs.s_star, at[2:-2, 2:-2, :], sss_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
if settings.enable_idemix:
tidal_energy_data = self._read_forcing("tidal_energy")
mask = (
npx.maximum(0, vs.kbot[2:-2, 2:-2] - 1)[:, :, npx.newaxis]
== npx.arange(settings.nz)[npx.newaxis, npx.newaxis, :]
)
tidal_energy_data *= vs.maskW[2:-2, 2:-2, :][mask].reshape(settings.nx, settings.ny) / settings.rho_0
vs.forc_iw_bottom = update(vs.forc_iw_bottom, at[2:-2, 2:-2], tidal_energy_data)
wind_energy_data = self._read_forcing("wind_energy")
wind_energy_data *= vs.maskW[2:-2, 2:-2, -1] / settings.rho_0 * 0.2
vs.forc_iw_surface = update(vs.forc_iw_surface, at[2:-2, 2:-2], wind_energy_data)
"""
Initialize penetration profile for solar radiation and store divergence in divpen
note that pen is set to 0.0 at the surface instead of 1.0 to compensate for the
shortwave part of the total surface flux
"""
swarg1 = vs.zw / efold1_shortwave
swarg2 = vs.zw / efold2_shortwave
pen = rpart_shortwave * npx.exp(swarg1) + (1.0 - rpart_shortwave) * npx.exp(swarg2)
pen = update(pen, at[-1], 0.0)
vs.divpen_shortwave = allocate(state.dimensions, ("zt",))
vs.divpen_shortwave = update(vs.divpen_shortwave, at[1:], (pen[1:] - pen[:-1]) / vs.dzt[1:])
vs.divpen_shortwave = update(vs.divpen_shortwave, at[0], pen[0] / vs.dzt[0])
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.update(set_forcing_kernel(state))
@veros_routine
def set_diagnostics(self, state):
settings = state.settings
average_vars = [
"surface_taux",
"surface_tauy",
"forc_temp_surface",
"forc_salt_surface",
"psi",
"temp",
"salt",
"u",
"v",
"w",
"Nsqr",
"Hd",
"rho",
"K_diss_v",
"P_diss_v",
"P_diss_nonlin",
"P_diss_iso",
"kappaH",
]
if settings.enable_skew_diffusion:
average_vars += ["B1_gm", "B2_gm"]
if settings.enable_TEM_friction:
average_vars += ["kappa_gm", "K_diss_gm"]
if settings.enable_tke:
average_vars += ["tke", "Prandtlnumber", "mxl", "tke_diss", "forc_tke_surface", "tke_surf_corr"]
if settings.enable_idemix:
average_vars += ["E_iw", "forc_iw_surface", "forc_iw_bottom", "iw_diss", "c0", "v0"]
if settings.enable_eke:
average_vars += ["eke", "K_gm", "L_rossby", "L_rhines"]
state.diagnostics["averages"].output_variables = average_vars
state.diagnostics["cfl_monitor"].output_frequency = 86400.0
state.diagnostics["snapshot"].output_frequency = 365 * 86400 / 24.0
state.diagnostics["overturning"].output_frequency = 365 * 86400
state.diagnostics["overturning"].sampling_frequency = 365 * 86400 / 24.0
state.diagnostics["energy"].output_frequency = 365 * 86400
state.diagnostics["energy"].sampling_frequency = 365 * 86400 / 24.0
state.diagnostics["averages"].output_frequency = 365 * 86400
state.diagnostics["averages"].sampling_frequency = 365 * 86400 / 24.0
@veros_routine
def after_timestep(self, state):
pass
@veros_kernel
def set_forcing_kernel(state):
vs = state.variables
settings = state.settings
t_rest = 30.0 * 86400.0
cp_0 = 3991.86795711963 # J/kg /K
year_in_seconds = time.convert_time(1.0, "years", "seconds")
(n1, f1), (n2, f2) = tools.get_periodic_interval(vs.time, year_in_seconds, year_in_seconds / 12.0, 12)
# linearly interpolate wind stress and shift from MITgcm U/V grid to this grid
vs.surface_taux = update(vs.surface_taux, at[:-1, :], f1 * vs.taux[1:, :, n1] + f2 * vs.taux[1:, :, n2])
vs.surface_tauy = update(vs.surface_tauy, at[:, :-1], f1 * vs.tauy[:, 1:, n1] + f2 * vs.tauy[:, 1:, n2])
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[1:-1, 1:-1],
npx.sqrt(
(0.5 * (vs.surface_taux[1:-1, 1:-1] + vs.surface_taux[:-2, 1:-1]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[1:-1, 1:-1] + vs.surface_tauy[1:-1, :-2]) / settings.rho_0) ** 2
)
** (3.0 / 2.0),
)
# W/m^2 K kg/J m^3/kg = K m/s
t_star_cur = f1 * vs.t_star[..., n1] + f2 * vs.t_star[..., n2]
qqnec = f1 * vs.qnec[..., n1] + f2 * vs.qnec[..., n2]
qqnet = f1 * vs.qnet[..., n1] + f2 * vs.qnet[..., n2]
vs.forc_temp_surface = (
(qqnet + qqnec * (t_star_cur - vs.temp[..., -1, vs.tau])) * vs.maskT[..., -1] / cp_0 / settings.rho_0
)
s_star_cur = f1 * vs.s_star[..., n1] + f2 * vs.s_star[..., n2]
vs.forc_salt_surface = 1.0 / t_rest * (s_star_cur - vs.salt[..., -1, vs.tau]) * vs.maskT[..., -1] * vs.dzt[-1]
# apply simple ice mask
mask1 = vs.temp[:, :, -1, vs.tau] * vs.maskT[:, :, -1] > -1.8
mask2 = vs.forc_temp_surface > 0
ice = npx.logical_or(mask1, mask2)
vs.forc_temp_surface *= ice
vs.forc_salt_surface *= ice
# solar radiation
if settings.enable_tempsalt_sources:
vs.temp_source = (
(f1 * vs.qsol[..., n1, None] + f2 * vs.qsol[..., n2, None])
* vs.divpen_shortwave[None, None, :]
* ice[..., None]
* vs.maskT[..., :]
/ cp_0
/ settings.rho_0
)
return KernelOutput(
surface_taux=vs.surface_taux,
surface_tauy=vs.surface_tauy,
temp_source=vs.temp_source,
forc_tke_surface=vs.forc_tke_surface,
forc_temp_surface=vs.forc_temp_surface,
forc_salt_surface=vs.forc_salt_surface,
)
from veros.setups.global_4deg.global_4deg import GlobalFourDegreeSetup # noqa: F401
{
"forcing": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/global_4deg/forcing_4deg_global_open_itf.nc",
"md5": "cfcc6d8cde8da5a74ecec00309d92dd7"
},
"ecmwf": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/global_4deg/ecmwf_4deg_monthly_nc4.nc",
"md5": "d1b4e0e199d7a5883cf7c88d3d6bcb27"
}
}
import os
import h5netcdf
import veros.tools
from veros import VerosSetup, veros_routine, veros_kernel, KernelOutput, logger
from veros.variables import Variable
from veros.core.operators import numpy as npx, update, at
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
DATA_FILES = veros.tools.get_assets("global_4deg", os.path.join(BASE_PATH, "assets.json"))
class GlobalFourDegreeSetup(VerosSetup):
"""Global 4 degree model with 15 vertical levels.
This setup demonstrates:
- setting up a realistic model
- reading input data from external files
- including Indonesian throughflow
- implementing surface forcings
- applying a simple ice mask
`Adapted from pyOM2 <https://wiki.cen.uni-hamburg.de/ifm/TO/pyOM2/4x4%20global%20model>`_.
ChangeLog
- 07-05-2020: modify bathymetry in order to include Indonesian throughflow;
courtesy of Franka Jesse, Utrecht University
"""
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "global_4deg"
settings.description = "My global 4 degree setup"
settings.nx, settings.ny, settings.nz = 90, 40, 15
settings.dt_mom = 1800.0
settings.dt_tracer = 86400.0
settings.runlen = 0.0
settings.x_origin = 4.0
settings.y_origin = -76.0
settings.coord_degree = True
settings.enable_cyclic_x = True
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 1000.0
settings.iso_dslope = 4.0 / 1000.0
settings.iso_slopec = 1.0 / 1000.0
settings.enable_skew_diffusion = True
settings.enable_hor_friction = True
settings.A_h = (4 * settings.degtom) ** 3 * 2e-11
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_implicit_vert_friction = True
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_kappaH_profile = True
settings.enable_tke_superbee_advection = True
settings.enable_eke = True
settings.eke_k_max = 1e4
settings.eke_c_k = 0.4
settings.eke_c_eps = 0.5
settings.eke_cross = 2.0
settings.eke_crhin = 1.0
settings.eke_lmin = 100.0
settings.enable_eke_superbee_advection = True
settings.enable_idemix = False
settings.enable_idemix_hor_diffusion = True
settings.enable_eke_diss_surfbot = True
settings.eke_diss_surfbot_frac = 0.2
settings.enable_idemix_superbee_advection = True
settings.eq_of_state_type = 5
# custom variables
state.dimensions["nmonths"] = 12
state.var_meta.update(
sss_clim=Variable("sss_clim", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
sst_clim=Variable("sst_clim", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnec=Variable("qnec", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnet=Variable("qnet", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
taux=Variable("taux", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
tauy=Variable("tauy", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
)
def _read_forcing(self, var):
with h5netcdf.File(DATA_FILES["forcing"], "r") as infile:
var_obj = infile.variables[var]
return npx.array(var_obj).T
@veros_routine
def set_grid(self, state):
vs = state.variables
ddz = npx.array(
[50.0, 70.0, 100.0, 140.0, 190.0, 240.0, 290.0, 340.0, 390.0, 440.0, 490.0, 540.0, 590.0, 640.0, 690.0]
)
vs.dzt = ddz[::-1]
vs.dxt = 4.0 * npx.ones_like(vs.dxt)
vs.dyt = 4.0 * npx.ones_like(vs.dyt)
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[npx.newaxis, :] / 180.0 * settings.pi)
)
@veros_routine(dist_safe=False, local_variables=["kbot", "zt"])
def set_topography(self, state):
vs = state.variables
settings = state.settings
bathymetry_data = self._read_forcing("bathymetry")
salt_data = self._read_forcing("salinity")[:, :, ::-1]
land_mask = (vs.zt[npx.newaxis, npx.newaxis, :] <= bathymetry_data[..., npx.newaxis]) | (salt_data == 0.0)
vs.kbot = update(vs.kbot, at[2:-2, 2:-2], 1 + npx.sum(land_mask.astype("int"), axis=2))
# set all-land cells
all_land_mask = (bathymetry_data == 0) | (vs.kbot[2:-2, 2:-2] == settings.nz)
vs.kbot = update(vs.kbot, at[2:-2, 2:-2], npx.where(all_land_mask, 0, vs.kbot[2:-2, 2:-2]))
@veros_routine(
dist_safe=False,
local_variables=[
"taux",
"tauy",
"qnec",
"qnet",
"sss_clim",
"sst_clim",
"temp",
"salt",
"area_t",
"maskT",
"forc_iw_bottom",
"forc_iw_surface",
],
)
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
# initial conditions for T and S
temp_data = self._read_forcing("temperature")[:, :, ::-1]
vs.temp = update(
vs.temp, at[2:-2, 2:-2, :, :2], temp_data[:, :, :, npx.newaxis] * vs.maskT[2:-2, 2:-2, :, npx.newaxis]
)
salt_data = self._read_forcing("salinity")[:, :, ::-1]
vs.salt = update(
vs.salt, at[2:-2, 2:-2, :, :2], salt_data[..., npx.newaxis] * vs.maskT[2:-2, 2:-2, :, npx.newaxis]
)
# use Trenberth wind stress from MITgcm instead of ECMWF (also contained in ecmwf_4deg.cdf)
vs.taux = update(vs.taux, at[2:-2, 2:-2, :], self._read_forcing("tau_x"))
vs.tauy = update(vs.tauy, at[2:-2, 2:-2, :], self._read_forcing("tau_y"))
# heat flux
with h5netcdf.File(DATA_FILES["ecmwf"], "r") as ecmwf_data:
qnec_var = ecmwf_data.variables["Q3"]
vs.qnec = update(vs.qnec, at[2:-2, 2:-2, :], npx.array(qnec_var).T)
vs.qnec = npx.where(vs.qnec <= -1e10, 0.0, vs.qnec)
q = self._read_forcing("q_net")
vs.qnet = update(vs.qnet, at[2:-2, 2:-2, :], -q)
vs.qnet = npx.where(vs.qnet <= -1e10, 0.0, vs.qnet)
mean_flux = (
npx.sum(vs.qnet[2:-2, 2:-2, :] * vs.area_t[2:-2, 2:-2, npx.newaxis]) / 12 / npx.sum(vs.area_t[2:-2, 2:-2])
)
logger.info(" removing an annual mean heat flux imbalance of %e W/m^2" % mean_flux)
vs.qnet = (vs.qnet - mean_flux) * vs.maskT[:, :, -1, npx.newaxis]
# SST and SSS
vs.sst_clim = update(vs.sst_clim, at[2:-2, 2:-2, :], self._read_forcing("sst"))
vs.sss_clim = update(vs.sss_clim, at[2:-2, 2:-2, :], self._read_forcing("sss"))
if settings.enable_idemix:
vs.forc_iw_bottom = update(
vs.forc_iw_bottom, at[2:-2, 2:-2], self._read_forcing("tidal_energy") / settings.rho_0
)
vs.forc_iw_surface = update(
vs.forc_iw_surface, at[2:-2, 2:-2], self._read_forcing("wind_energy") / settings.rho_0 * 0.2
)
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.update(set_forcing_kernel(state))
@veros_routine
def set_diagnostics(self, state):
settings = state.settings
state.diagnostics["snapshot"].output_frequency = 360 * 86400.0
state.diagnostics["overturning"].output_frequency = 360 * 86400.0
state.diagnostics["overturning"].sampling_frequency = settings.dt_tracer
state.diagnostics["energy"].output_frequency = 360 * 86400.0
state.diagnostics["energy"].sampling_frequency = 86400
average_vars = ["temp", "salt", "u", "v", "w", "surface_taux", "surface_tauy", "psi"]
state.diagnostics["averages"].output_variables = average_vars
state.diagnostics["averages"].output_frequency = 360 * 86400.0
state.diagnostics["averages"].sampling_frequency = 86400
@veros_routine
def after_timestep(self, state):
pass
@veros_kernel
def set_forcing_kernel(state):
vs = state.variables
settings = state.settings
year_in_seconds = 360 * 86400.0
(n1, f1), (n2, f2) = veros.tools.get_periodic_interval(vs.time, year_in_seconds, year_in_seconds / 12.0, 12)
# wind stress
vs.surface_taux = f1 * vs.taux[:, :, n1] + f2 * vs.taux[:, :, n2]
vs.surface_tauy = f1 * vs.tauy[:, :, n1] + f2 * vs.tauy[:, :, n2]
# tke flux
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[1:-1, 1:-1],
npx.sqrt(
(0.5 * (vs.surface_taux[1:-1, 1:-1] + vs.surface_taux[:-2, 1:-1]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[1:-1, 1:-1] + vs.surface_tauy[1:-1, :-2]) / settings.rho_0) ** 2
)
** 1.5,
)
# heat flux : W/m^2 K kg/J m^3/kg = K m/s
cp_0 = 3991.86795711963
sst = f1 * vs.sst_clim[:, :, n1] + f2 * vs.sst_clim[:, :, n2]
qnec = f1 * vs.qnec[:, :, n1] + f2 * vs.qnec[:, :, n2]
qnet = f1 * vs.qnet[:, :, n1] + f2 * vs.qnet[:, :, n2]
vs.forc_temp_surface = (
(qnet + qnec * (sst - vs.temp[:, :, -1, vs.tau])) * vs.maskT[:, :, -1] / cp_0 / settings.rho_0
)
# salinity restoring
t_rest = 30 * 86400.0
sss = f1 * vs.sss_clim[:, :, n1] + f2 * vs.sss_clim[:, :, n2]
vs.forc_salt_surface = 1.0 / t_rest * (sss - vs.salt[:, :, -1, vs.tau]) * vs.maskT[:, :, -1] * vs.dzt[-1]
# apply simple ice mask
mask = npx.logical_and(vs.temp[:, :, -1, vs.tau] * vs.maskT[:, :, -1] < -1.8, vs.forc_temp_surface < 0.0)
vs.forc_temp_surface = npx.where(mask, 0.0, vs.forc_temp_surface)
vs.forc_salt_surface = npx.where(mask, 0.0, vs.forc_salt_surface)
return KernelOutput(
surface_taux=vs.surface_taux,
surface_tauy=vs.surface_tauy,
forc_tke_surface=vs.forc_tke_surface,
forc_temp_surface=vs.forc_temp_surface,
forc_salt_surface=vs.forc_salt_surface,
)
from veros.setups.global_flexible.global_flexible import GlobalFlexibleResolutionSetup # noqa: F401
{
"topography": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/wave_propagation/ETOPO5_Ice_g_gmt4.nc",
"md5": "12853ace540766bdf11f6f73655cb63a"
},
"forcing": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/global_flexible/forcing_1deg_global_interpolated.nc",
"md5": "902a6cd90c2d814dd4e0704d39013981"
}
}
#!/usr/bin/env python
import os
import h5netcdf
import scipy.ndimage
from veros import veros_routine, veros_kernel, KernelOutput, VerosSetup, runtime_settings as rs, runtime_state as rst
from veros.variables import Variable, allocate
from veros.core.utilities import enforce_boundaries
from veros.core.operators import numpy as npx, update, at
import veros.tools
import veros.time
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
DATA_FILES = veros.tools.get_assets("global_flexible", os.path.join(BASE_PATH, "assets.json"))
class GlobalFlexibleResolutionSetup(VerosSetup):
"""
Global model with flexible resolution.
"""
# global settings
min_depth = 10.0
max_depth = 5400.0
equatorial_grid_spacing_factor = 0.5
polar_grid_spacing_factor = None
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "global_flexible"
settings.description = "Global model with flexible resolution"
settings.nx = 360
settings.ny = 160
settings.nz = 60
settings.dt_mom = settings.dt_tracer = 900
settings.runlen = 86400 * 10
settings.x_origin = 90.0
settings.y_origin = -80.0
settings.coord_degree = True
settings.enable_cyclic_x = True
# friction
settings.enable_hor_friction = True
settings.A_h = 5e4
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_tempsalt_sources = True
settings.enable_implicit_vert_friction = True
settings.eq_of_state_type = 5
# isoneutral
settings.enable_neutral_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 50.0
settings.iso_dslope = 0.005
settings.iso_slopec = 0.005
settings.enable_skew_diffusion = True
# tke
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_kappaH_profile = True
settings.enable_tke_superbee_advection = True
# eke
settings.enable_eke = True
settings.eke_k_max = 1e4
settings.eke_c_k = 0.4
settings.eke_c_eps = 0.5
settings.eke_cross = 2.0
settings.eke_crhin = 1.0
settings.eke_lmin = 100.0
settings.enable_eke_superbee_advection = True
settings.enable_eke_isopycnal_diffusion = True
# idemix
settings.enable_idemix = False
settings.enable_eke_diss_surfbot = True
settings.eke_diss_surfbot_frac = 0.2
settings.enable_idemix_superbee_advection = True
settings.enable_idemix_hor_diffusion = True
# custom variables
state.dimensions["nmonths"] = 12
state.var_meta.update(
t_star=Variable("t_star", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
s_star=Variable("s_star", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnec=Variable("qnec", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qnet=Variable("qnet", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
qsol=Variable("qsol", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
divpen_shortwave=Variable("divpen_shortwave", ("zt",), "", "", time_dependent=False),
taux=Variable("taux", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
tauy=Variable("tauy", ("xt", "yt", "nmonths"), "", "", time_dependent=False),
)
def _get_data(self, var, idx=None):
if idx is None:
idx = Ellipsis
else:
idx = idx[::-1]
kwargs = {}
if rst.proc_num > 1:
kwargs.update(
driver="mpio",
comm=rs.mpi_comm,
)
with h5netcdf.File(DATA_FILES["forcing"], "r", **kwargs) as forcing_file:
var_obj = forcing_file.variables[var]
return npx.array(var_obj[idx]).T
@veros_routine(dist_safe=False, local_variables=["dxt", "dyt", "dzt"])
def set_grid(self, state):
vs = state.variables
settings = state.settings
if settings.ny % 2:
raise ValueError("ny has to be an even number of grid cells")
vs.dxt = update(vs.dxt, at[...], 360.0 / settings.nx)
if self.equatorial_grid_spacing_factor is not None:
eq_spacing = self.equatorial_grid_spacing_factor * 160.0 / settings.ny
else:
eq_spacing = None
if self.polar_grid_spacing_factor is not None:
polar_spacing = self.polar_grid_spacing_factor * 160.0 / settings.ny
else:
polar_spacing = None
vs.dyt = update(
vs.dyt,
at[2:-2],
veros.tools.get_vinokur_grid_steps(
settings.ny, 160.0, eq_spacing, upper_stepsize=polar_spacing, two_sided_grid=True
),
)
vs.dzt = veros.tools.get_vinokur_grid_steps(settings.nz, self.max_depth, self.min_depth, refine_towards="lower")
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[npx.newaxis, :] / 180.0 * settings.pi)
)
def _shift_longitude_array(self, vs, lon, arr):
wrap_i = npx.where((lon[:-1] < vs.xt.min()) & (lon[1:] >= vs.xt.min()))[0][0]
new_lon = npx.concatenate((lon[wrap_i:-1], lon[:wrap_i] + 360.0))
new_arr = npx.concatenate((arr[wrap_i:-1, ...], arr[:wrap_i, ...]))
return new_lon, new_arr
@veros_routine(dist_safe=False, local_variables=["kbot", "xt", "yt", "zt"])
def set_topography(self, state):
vs = state.variables
settings = state.settings
with h5netcdf.File(DATA_FILES["topography"], "r") as topography_file:
topo_x, topo_y, topo_z = (npx.array(topography_file.variables[k], dtype="float").T for k in ("x", "y", "z"))
topo_z = npx.minimum(topo_z, 0.0)
# smooth topography to match grid resolution
gaussian_sigma = (0.5 * len(topo_x) / settings.nx, 0.5 * len(topo_y) / settings.ny)
topo_z_smoothed = scipy.ndimage.gaussian_filter(topo_z, sigma=gaussian_sigma)
topo_z_smoothed = npx.where(topo_z >= -1, 0, topo_z_smoothed)
topo_x_shifted, topo_z_shifted = self._shift_longitude_array(vs, topo_x, topo_z_smoothed)
coords = (vs.xt[2:-2], vs.yt[2:-2])
z_interp = allocate(state.dimensions, ("xt", "yt"), local=False)
z_interp = update(
z_interp,
at[2:-2, 2:-2],
veros.tools.interpolate((topo_x_shifted, topo_y), topo_z_shifted, coords, kind="nearest", fill=False),
)
depth_levels = 1 + npx.argmin(npx.abs(z_interp[:, :, npx.newaxis] - vs.zt[npx.newaxis, npx.newaxis, :]), axis=2)
vs.kbot = update(vs.kbot, at[2:-2, 2:-2], npx.where(z_interp < 0.0, depth_levels, 0)[2:-2, 2:-2])
vs.kbot = npx.where(vs.kbot < settings.nz, vs.kbot, 0)
vs.kbot = enforce_boundaries(vs.kbot, settings.enable_cyclic_x, local=True)
# remove marginal seas
# (dilate to close 1-cell passages, fill holes, undo dilation)
marginal = scipy.ndimage.binary_erosion(
scipy.ndimage.binary_fill_holes(scipy.ndimage.binary_dilation(vs.kbot == 0))
)
vs.kbot = npx.where(marginal, 0, vs.kbot)
@veros_routine
def set_initial_conditions(self, state):
vs = state.variables
settings = state.settings
rpart_shortwave = 0.58
efold1_shortwave = 0.35
efold2_shortwave = 23.0
t_grid = (vs.xt[2:-2], vs.yt[2:-2], vs.zt)
xt_forc, yt_forc, zt_forc = (self._get_data(k) for k in ("xt", "yt", "zt"))
zt_forc = zt_forc[::-1]
# coordinates must be monotonous for this to work
assert npx.diff(xt_forc).all() > 0
assert npx.diff(yt_forc).all() > 0
# determine slice to read from forcing file
data_subset = (
slice(
max(0, int(npx.argmax(xt_forc >= vs.xt.min())) - 1),
len(xt_forc) - max(0, int(npx.argmax(xt_forc[::-1] <= vs.xt.max())) - 1),
),
slice(
max(0, int(npx.argmax(yt_forc >= vs.yt.min())) - 1),
len(yt_forc) - max(0, int(npx.argmax(yt_forc[::-1] <= vs.yt.max())) - 1),
),
Ellipsis,
)
xt_forc = xt_forc[data_subset[0]]
yt_forc = yt_forc[data_subset[1]]
# initial conditions
temp_raw = self._get_data("temperature", idx=data_subset)[..., ::-1]
temp_data = veros.tools.interpolate((xt_forc, yt_forc, zt_forc), temp_raw, t_grid)
vs.temp = update(vs.temp, at[2:-2, 2:-2, :, :], (temp_data * vs.maskT[2:-2, 2:-2, :])[..., npx.newaxis])
salt_raw = self._get_data("salinity", idx=data_subset)[..., ::-1]
salt_data = veros.tools.interpolate((xt_forc, yt_forc, zt_forc), salt_raw, t_grid)
vs.salt = update(vs.salt, at[2:-2, 2:-2, :, :], (salt_data * vs.maskT[2:-2, 2:-2, :])[..., npx.newaxis])
# wind stress on MIT grid
time_grid = (vs.xt[2:-2], vs.yt[2:-2], npx.arange(12))
taux_raw = self._get_data("tau_x", idx=data_subset)
taux_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), taux_raw, time_grid)
vs.taux = update(vs.taux, at[2:-2, 2:-2, :], taux_data)
tauy_raw = self._get_data("tau_y", idx=data_subset)
tauy_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), tauy_raw, time_grid)
vs.tauy = update(vs.tauy, at[2:-2, 2:-2, :], tauy_data)
vs.taux = enforce_boundaries(vs.taux, settings.enable_cyclic_x)
vs.tauy = enforce_boundaries(vs.tauy, settings.enable_cyclic_x)
# Qnet and dQ/dT and Qsol
qnet_raw = self._get_data("q_net", idx=data_subset)
qnet_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), qnet_raw, time_grid)
vs.qnet = update(vs.qnet, at[2:-2, 2:-2, :], -qnet_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
qnec_raw = self._get_data("dqdt", idx=data_subset)
qnec_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), qnec_raw, time_grid)
vs.qnec = update(vs.qnec, at[2:-2, 2:-2, :], qnec_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
qsol_raw = self._get_data("swf", idx=data_subset)
qsol_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), qsol_raw, time_grid)
vs.qsol = update(vs.qsol, at[2:-2, 2:-2, :], -qsol_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
# SST and SSS
sst_raw = self._get_data("sst", idx=data_subset)
sst_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), sst_raw, time_grid)
vs.t_star = update(vs.t_star, at[2:-2, 2:-2, :], sst_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
sss_raw = self._get_data("sss", idx=data_subset)
sss_data = veros.tools.interpolate((xt_forc, yt_forc, npx.arange(12)), sss_raw, time_grid)
vs.s_star = update(vs.s_star, at[2:-2, 2:-2, :], sss_data * vs.maskT[2:-2, 2:-2, -1, npx.newaxis])
if settings.enable_idemix:
tidal_energy_raw = self._get_data("tidal_energy", idx=data_subset)
tidal_energy_data = veros.tools.interpolate((xt_forc, yt_forc), tidal_energy_raw, t_grid[:-1])
mask_x, mask_y = (i + 2 for i in npx.indices((vs.nx, vs.ny)))
mask_z = npx.maximum(0, vs.kbot[2:-2, 2:-2] - 1)
tidal_energy_data[:, :] *= vs.maskW[mask_x, mask_y, mask_z] / vs.rho_0
vs.forc_iw_bottom[2:-2, 2:-2] = tidal_energy_data
"""
Initialize penetration profile for solar radiation and store divergence in divpen
note that pen is set to 0.0 at the surface instead of 1.0 to compensate for the
shortwave part of the total surface flux
"""
swarg1 = vs.zw / efold1_shortwave
swarg2 = vs.zw / efold2_shortwave
pen = rpart_shortwave * npx.exp(swarg1) + (1.0 - rpart_shortwave) * npx.exp(swarg2)
pen = update(pen, at[-1], 0.0)
vs.divpen_shortwave = update(vs.divpen_shortwave, at[1:], (pen[1:] - pen[:-1]) / vs.dzt[1:])
vs.divpen_shortwave = update(vs.divpen_shortwave, at[0], pen[0] / vs.dzt[0])
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.update(set_forcing_kernel(state))
@veros_routine
def set_diagnostics(self, state):
settings = state.settings
diagnostics = state.diagnostics
diagnostics["cfl_monitor"].output_frequency = settings.dt_tracer * 100
diagnostics["tracer_monitor"].output_frequency = settings.dt_tracer * 100
diagnostics["snapshot"].output_frequency = 30 * 86400.0
diagnostics["overturning"].output_frequency = 360 * 86400
diagnostics["overturning"].sampling_frequency = 86400.0
diagnostics["energy"].output_frequency = 360 * 86400
diagnostics["energy"].sampling_frequency = 10 * settings.dt_tracer
diagnostics["averages"].output_frequency = 30 * 86400
diagnostics["averages"].sampling_frequency = settings.dt_tracer
average_vars = [
"surface_taux",
"surface_tauy",
"forc_temp_surface",
"forc_salt_surface",
"psi",
"temp",
"salt",
"u",
"v",
"w",
"Nsqr",
"Hd",
"rho",
"kappaH",
]
if settings.enable_skew_diffusion:
average_vars += ["B1_gm", "B2_gm"]
if settings.enable_TEM_friction:
average_vars += ["kappa_gm", "K_diss_gm"]
if settings.enable_tke:
average_vars += ["tke", "Prandtlnumber", "mxl", "tke_diss", "forc_tke_surface", "tke_surf_corr"]
if settings.enable_idemix:
average_vars += ["E_iw", "forc_iw_surface", "iw_diss", "c0", "v0"]
if settings.enable_eke:
average_vars += ["eke", "K_gm", "L_rossby", "L_rhines"]
diagnostics["averages"].output_variables = average_vars
@veros_routine
def after_timestep(self, state):
pass
@veros_kernel
def set_forcing_kernel(state):
vs = state.variables
settings = state.settings
t_rest = 30.0 * 86400.0
cp_0 = 3991.86795711963 # J/kg /K
year_in_seconds = veros.time.convert_time(1.0, "years", "seconds")
(n1, f1), (n2, f2) = veros.tools.get_periodic_interval(vs.time, year_in_seconds, year_in_seconds / 12.0, 12)
# linearly interpolate wind stress and shift from MITgcm U/V grid to this grid
vs.surface_taux = update(vs.surface_taux, at[:-1, :], f1 * vs.taux[1:, :, n1] + f2 * vs.taux[1:, :, n2])
vs.surface_tauy = update(vs.surface_tauy, at[:, :-1], f1 * vs.tauy[:, 1:, n1] + f2 * vs.tauy[:, 1:, n2])
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[1:-1, 1:-1],
npx.sqrt(
(0.5 * (vs.surface_taux[1:-1, 1:-1] + vs.surface_taux[:-2, 1:-1]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[1:-1, 1:-1] + vs.surface_tauy[1:-1, :-2]) / settings.rho_0) ** 2
)
** (3.0 / 2.0),
)
# W/m^2 K kg/J m^3/kg = K m/s
t_star_cur = f1 * vs.t_star[..., n1] + f2 * vs.t_star[..., n2]
qqnec = f1 * vs.qnec[..., n1] + f2 * vs.qnec[..., n2]
qqnet = f1 * vs.qnet[..., n1] + f2 * vs.qnet[..., n2]
vs.forc_temp_surface = (
(qqnet + qqnec * (t_star_cur - vs.temp[..., -1, vs.tau])) * vs.maskT[..., -1] / cp_0 / settings.rho_0
)
s_star_cur = f1 * vs.s_star[..., n1] + f2 * vs.s_star[..., n2]
vs.forc_salt_surface = 1.0 / t_rest * (s_star_cur - vs.salt[..., -1, vs.tau]) * vs.maskT[..., -1] * vs.dzt[-1]
# apply simple ice mask
mask1 = vs.temp[:, :, -1, vs.tau] * vs.maskT[:, :, -1] > -1.8
mask2 = vs.forc_temp_surface > 0
ice = npx.logical_or(mask1, mask2)
vs.forc_temp_surface *= ice
vs.forc_salt_surface *= ice
# solar radiation
if settings.enable_tempsalt_sources:
vs.temp_source = (
(f1 * vs.qsol[..., n1, None] + f2 * vs.qsol[..., n2, None])
* vs.divpen_shortwave[None, None, :]
* ice[..., None]
* vs.maskT[..., :]
/ cp_0
/ settings.rho_0
)
return KernelOutput(
surface_taux=vs.surface_taux,
surface_tauy=vs.surface_tauy,
temp_source=vs.temp_source,
forc_tke_surface=vs.forc_tke_surface,
forc_temp_surface=vs.forc_temp_surface,
forc_salt_surface=vs.forc_salt_surface,
)
from veros.setups.north_atlantic.north_atlantic import NorthAtlanticSetup # noqa: F401
{
"topography": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/north_atlantic/ETOPO1_Bed_g_gmt4_NA.nc",
"md5": "5bd5b5ea6f8ab529b9f74f560319b65d"
},
"forcing": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/north_atlantic/forcing_nc4.nc",
"md5": "9997ebf7798c0d1c28bcbdd406414f28"
},
"restoring": {
"url": "https://sid.erda.dk/share_redirect/gsdZADr8to/north_atlantic/restoring_zone.nc",
"md5": "618f6116e66be4349c1e7846dc44d06b"
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment