#!/usr/bin/env python
import os
import h5netcdf
from PIL import Image
import scipy.interpolate
import scipy.spatial
import scipy.ndimage
from veros import VerosSetup, veros_routine, veros_kernel, KernelOutput
from veros.variables import Variable
from veros.core.operators import numpy as npx, update, at
import veros.tools
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
DATA_FILES = veros.tools.get_assets("north_atlantic", os.path.join(BASE_PATH, "assets.json"))
TOPO_MASK_FILE = os.path.join(BASE_PATH, "topo_mask.png")
class NorthAtlanticSetup(VerosSetup):
"""A regional model of the North Atlantic, inspired by `Smith et al., 2000`_.
Forcing and initial conditions are taken from the FLAME PyOM2 setup. Bathymetry
data from ETOPO1 (resolution of 1 arcmin).
Boundary forcings are implemented via sponge layers in the Greenland Sea, by the
Strait of Gibraltar, and in the South Atlantic. This setup runs with arbitrary resolution;
upon changing the number of grid cells, all forcing files will be interpolated to
the new grid. Default resolution corresponds roughly to :math:`0.5 \\times 0.25` degrees.
.. _Smith et al., 2000:
http://journals.ametsoc.org/doi/10.1175/1520-0485%282000%29030%3C1532%3ANSOTNA%3E2.0.CO%3B2
"""
x_boundary = 17.2
y_boundary = 70.0
max_depth = 5800.0
@veros_routine
def set_parameter(self, state):
settings = state.settings
settings.identifier = "north_atlantic"
settings.description = "North Atlantic setup"
settings.nx, settings.ny, settings.nz = 250, 350, 50
settings.x_origin = -98.0
settings.y_origin = -18.0
settings.dt_mom = 3600.0 / 2.0
settings.dt_tracer = 3600.0 / 2.0
settings.runlen = 86400 * 365.0 * 10.0
settings.coord_degree = True
settings.enable_neutral_diffusion = True
settings.enable_skew_diffusion = True
settings.K_iso_0 = 1000.0
settings.K_iso_steep = 200.0
settings.iso_dslope = 1.0 / 1000.0
settings.iso_slopec = 4.0 / 1000.0
settings.enable_hor_friction = True
settings.A_h = 1e3
settings.enable_hor_friction_cos_scaling = True
settings.hor_friction_cosPower = 1
settings.enable_tempsalt_sources = True
settings.enable_implicit_vert_friction = True
settings.enable_tke = True
settings.c_k = 0.1
settings.c_eps = 0.7
settings.alpha_tke = 30.0
settings.mxl_min = 1e-8
settings.tke_mxl_choice = 2
settings.kappaM_min = 2e-4
settings.kappaH_min = 2e-5
settings.enable_kappaH_profile = True
settings.K_gm_0 = 1000.0
settings.enable_eke = False
settings.enable_idemix = False
settings.eq_of_state_type = 5
state.dimensions["nmonths"] = 12
state.var_meta.update(
{
"sss_clim": Variable("sss_clim", ("xt", "yt", "nmonths"), "g/kg", "Monthly sea surface salinity"),
"sst_clim": Variable("sst_clim", ("xt", "yt", "nmonths"), "deg C", "Monthly sea surface temperature"),
"sss_rest": Variable(
"sss_rest", ("xt", "yt", "nmonths"), "g/kg", "Monthly sea surface salinity restoring"
),
"sst_rest": Variable(
"sst_rest", ("xt", "yt", "nmonths"), "deg C", "Monthly sea surface temperature restoring"
),
"t_star": Variable(
"t_star", ("xt", "yt", "zt", "nmonths"), "deg C", "Temperature sponge layer forcing"
),
"s_star": Variable("s_star", ("xt", "yt", "zt", "nmonths"), "g/kg", "Salinity sponge layer forcing"),
"rest_tscl": Variable("rest_tscl", ("xt", "yt", "zt"), "1/s", "Forcing restoration time scale"),
"taux": Variable("taux", ("xt", "yt", "nmonths"), "N/s^2", "Monthly zonal wind stress"),
"tauy": Variable("tauy", ("xt", "yt", "nmonths"), "N/s^2", "Monthly meridional wind stress"),
}
)
@veros_routine
def set_grid(self, state):
vs = state.variables
settings = state.settings
vs.dxt = update(vs.dxt, at[2:-2], (self.x_boundary - settings.x_origin) / settings.nx)
vs.dyt = update(vs.dyt, at[2:-2], (self.y_boundary - settings.y_origin) / settings.ny)
vs.dzt = veros.tools.get_vinokur_grid_steps(settings.nz, self.max_depth, 10.0, refine_towards="lower")
@veros_routine
def set_coriolis(self, state):
vs = state.variables
settings = state.settings
vs.coriolis_t = update(
vs.coriolis_t, at[...], 2 * settings.omega * npx.sin(vs.yt[npx.newaxis, :] / 180.0 * settings.pi)
)
@veros_routine(dist_safe=False, local_variables=["kbot", "xt", "yt", "zt"])
def set_topography(self, state):
import numpy as onp
vs = state.variables
settings = state.settings
with h5netcdf.File(DATA_FILES["topography"], "r") as topo_file:
topo_x, topo_y, topo_bottom_depth = (self._get_data(topo_file, k) for k in ("x", "y", "z"))
topo_mask = npx.flipud(npx.asarray(Image.open(TOPO_MASK_FILE))).T
topo_bottom_depth = npx.where(topo_mask, 0, topo_bottom_depth)
topo_bottom_depth = scipy.ndimage.gaussian_filter(
topo_bottom_depth, sigma=(len(topo_x) / settings.nx, len(topo_y) / settings.ny)
)
interp_coords = npx.meshgrid(vs.xt[2:-2], vs.yt[2:-2], indexing="ij")
interp_coords = npx.rollaxis(npx.asarray(interp_coords), 0, 3)
z_interp = scipy.interpolate.interpn(
(onp.array(topo_x), onp.array(topo_y)),
topo_bottom_depth,
onp.array(interp_coords),
method="nearest",
bounds_error=False,
fill_value=0,
)
vs.kbot = update(
vs.kbot,
at[2:-2, 2:-2],
npx.where(
z_interp < 0.0,
1 + npx.argmin(npx.abs(z_interp[:, :, npx.newaxis] - vs.zt[npx.newaxis, npx.newaxis, :]), axis=2),
0,
),
)
vs.kbot = npx.where(vs.kbot < settings.nz, vs.kbot, 0)
def _get_data(self, f, var):
"""Retrieve variable from h5netcdf file"""
var_obj = f.variables[var]
return npx.array(var_obj).T
@veros_routine(
dist_safe=False,
local_variables=[
"tau",
"xt",
"yt",
"zt",
"temp",
"maskT",
"salt",
"taux",
"tauy",
"sst_clim",
"sss_clim",
"sst_rest",
"sss_rest",
"t_star",
"s_star",
"rest_tscl",
],
)
def set_initial_conditions(self, state):
vs = state.variables
with h5netcdf.File(DATA_FILES["forcing"], "r") as forcing_file:
t_hor = (vs.xt[2:-2], vs.yt[2:-2])
t_grid = (vs.xt[2:-2], vs.yt[2:-2], vs.zt)
forc_coords = [self._get_data(forcing_file, k) for k in ("xt", "yt", "zt")]
forc_coords[0] = forc_coords[0] - 360
forc_coords[2] = -0.01 * forc_coords[2][::-1]
temp_raw = self._get_data(forcing_file, "temp_ic")[..., ::-1]
temp = veros.tools.interpolate(forc_coords, temp_raw, t_grid, missing_value=-1e20)
vs.temp = update(vs.temp, at[2:-2, 2:-2, :, vs.tau], vs.maskT[2:-2, 2:-2, :] * temp)
salt_raw = self._get_data(forcing_file, "salt_ic")[..., ::-1]
salt = 35.0 + 1000 * veros.tools.interpolate(forc_coords, salt_raw, t_grid, missing_value=-1e20)
vs.salt = update(vs.salt, at[2:-2, 2:-2, :, vs.tau], vs.maskT[2:-2, 2:-2, :] * salt)
forc_u_coords_hor = [self._get_data(forcing_file, k) for k in ("xu", "yu")]
forc_u_coords_hor[0] = forc_u_coords_hor[0] - 360
taux = self._get_data(forcing_file, "taux")
tauy = self._get_data(forcing_file, "tauy")
for k in range(12):
vs.taux = update(
vs.taux,
at[2:-2, 2:-2, k],
(veros.tools.interpolate(forc_u_coords_hor, taux[..., k], t_hor, missing_value=-1e20) / 10.0),
)
vs.tauy = update(
vs.tauy,
at[2:-2, 2:-2, k],
(veros.tools.interpolate(forc_u_coords_hor, tauy[..., k], t_hor, missing_value=-1e20) / 10.0),
)
# heat flux and salinity restoring
sst_clim, sss_clim, sst_rest, sss_rest = [
forcing_file.variables[k][...].T for k in ("sst_clim", "sss_clim", "sst_rest", "sss_rest")
]
for k in range(12):
vs.sst_clim = update(
vs.sst_clim,
at[2:-2, 2:-2, k],
veros.tools.interpolate(forc_coords[:-1], sst_clim[..., k], t_hor, missing_value=-1e20),
)
vs.sss_clim = update(
vs.sss_clim,
at[2:-2, 2:-2, k],
(veros.tools.interpolate(forc_coords[:-1], sss_clim[..., k], t_hor, missing_value=-1e20) * 1000 + 35),
)
vs.sst_rest = update(
vs.sst_rest,
at[2:-2, 2:-2, k],
(veros.tools.interpolate(forc_coords[:-1], sst_rest[..., k], t_hor, missing_value=-1e20) * 41868.0),
)
vs.sss_rest = update(
vs.sss_rest,
at[2:-2, 2:-2, k],
(veros.tools.interpolate(forc_coords[:-1], sss_rest[..., k], t_hor, missing_value=-1e20) / 100.0),
)
with h5netcdf.File(DATA_FILES["restoring"], "r") as restoring_file:
rest_coords = [self._get_data(restoring_file, k) for k in ("xt", "yt", "zt")]
rest_coords[0] = rest_coords[0] - 360
# sponge layers
vs.rest_tscl = update(
vs.rest_tscl,
at[2:-2, 2:-2, :],
veros.tools.interpolate(rest_coords, self._get_data(restoring_file, "tscl")[..., 0], t_grid),
)
t_star = self._get_data(restoring_file, "t_star")
s_star = self._get_data(restoring_file, "s_star")
for k in range(12):
vs.t_star = update(
vs.t_star,
at[2:-2, 2:-2, :, k],
veros.tools.interpolate(rest_coords, t_star[..., k], t_grid, missing_value=0.0),
)
vs.s_star = update(
vs.s_star,
at[2:-2, 2:-2, :, k],
veros.tools.interpolate(rest_coords, s_star[..., k], t_grid, missing_value=0.0),
)
@veros_routine
def set_forcing(self, state):
vs = state.variables
vs.update(set_forcing_kernel(state))
@veros_routine
def set_diagnostics(self, state):
diagnostics = state.diagnostics
settings = state.settings
diagnostics["snapshot"].output_frequency = 3600.0 * 24 * 10
diagnostics["averages"].output_frequency = 3600.0 * 24 * 10
diagnostics["averages"].sampling_frequency = settings.dt_tracer
diagnostics["averages"].output_variables = [
"temp",
"salt",
"u",
"v",
"w",
"surface_taux",
"surface_tauy",
"psi",
]
diagnostics["cfl_monitor"].output_frequency = settings.dt_tracer * 10
@veros_routine
def after_timestep(self, state):
pass
@veros_kernel
def set_forcing_kernel(state):
vs = state.variables
settings = state.settings
year_in_seconds = 360 * 86400.0
(n1, f1), (n2, f2) = veros.tools.get_periodic_interval(vs.time, year_in_seconds, year_in_seconds / 12.0, 12)
vs.surface_taux = f1 * vs.taux[:, :, n1] + f2 * vs.taux[:, :, n2]
vs.surface_tauy = f1 * vs.tauy[:, :, n1] + f2 * vs.tauy[:, :, n2]
if settings.enable_tke:
vs.forc_tke_surface = update(
vs.forc_tke_surface,
at[1:-1, 1:-1],
npx.sqrt(
(0.5 * (vs.surface_taux[1:-1, 1:-1] + vs.surface_taux[:-2, 1:-1]) / settings.rho_0) ** 2
+ (0.5 * (vs.surface_tauy[1:-1, 1:-1] + vs.surface_tauy[1:-1, :-2]) / settings.rho_0) ** 2
)
** 1.5,
)
cp_0 = 3991.86795711963
vs.forc_temp_surface = (
(f1 * vs.sst_rest[:, :, n1] + f2 * vs.sst_rest[:, :, n2])
* (f1 * vs.sst_clim[:, :, n1] + f2 * vs.sst_clim[:, :, n2] - vs.temp[:, :, -1, vs.tau])
* vs.maskT[:, :, -1]
/ cp_0
/ settings.rho_0
)
vs.forc_salt_surface = (
(f1 * vs.sss_rest[:, :, n1] + f2 * vs.sss_rest[:, :, n2])
* (f1 * vs.sss_clim[:, :, n1] + f2 * vs.sss_clim[:, :, n2] - vs.salt[:, :, -1, vs.tau])
* vs.maskT[:, :, -1]
)
ice_mask = (vs.temp[:, :, -1, vs.tau] * vs.maskT[:, :, -1] <= -1.8) & (vs.forc_temp_surface <= 0.0)
vs.forc_temp_surface = npx.where(ice_mask, 0.0, vs.forc_temp_surface)
vs.forc_salt_surface = npx.where(ice_mask, 0.0, vs.forc_salt_surface)
if settings.enable_tempsalt_sources:
vs.temp_source = (
vs.maskT
* vs.rest_tscl
* (f1 * vs.t_star[:, :, :, n1] + f2 * vs.t_star[:, :, :, n2] - vs.temp[:, :, :, vs.tau])
)
vs.salt_source = (
vs.maskT
* vs.rest_tscl
* (f1 * vs.s_star[:, :, :, n1] + f2 * vs.s_star[:, :, :, n2] - vs.salt[:, :, :, vs.tau])
)
return KernelOutput(
surface_taux=vs.surface_taux,
surface_tauy=vs.surface_tauy,
temp_source=vs.temp_source,
salt_source=vs.salt_source,
forc_tke_surface=vs.forc_tke_surface,
forc_temp_surface=vs.forc_temp_surface,
forc_salt_surface=vs.forc_salt_surface,
)
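# A minimal sketch of how a setup like the one above is typically driven
# (standard VerosSetup entry points; the asset files in DATA_FILES are
# downloaded lazily on first access):
if __name__ == "__main__":
    simulation = NorthAtlanticSetup()
    simulation.setup()
    simulation.run()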
import signal
import contextlib
import functools
from veros import logger
def do_not_disturb(function):
"""Decorator that catches SIGINT and SIGTERM signals (e.g. after keyboard interrupt)
and makes sure that the function body is executed before exiting.
Useful for ensuring that output files are written properly.
"""
signals = (signal.SIGINT, signal.SIGTERM)
@functools.wraps(function)
def dnd_wrapper(*args, **kwargs):
old_handlers = {s: signal.getsignal(s) for s in signals}
signal_received = {"sig": None, "frame": None}
def handler(sig, frame):
if signal_received["sig"] is None:
signal_received["sig"] = sig
signal_received["frame"] = frame
logger.error(f"Signal {sig} received - cleaning up before exit")
else:
# force quit if more than one signal is received
old_handlers[sig](sig, frame)
for s in signals:
signal.signal(s, handler)
try:
res = function(*args, **kwargs)
finally:
for s in signals:
signal.signal(s, old_handlers[s])
sig = signal_received["sig"]
if sig is not None:
old_handlers[sig](signal_received["sig"], signal_received["frame"])
return res
return dnd_wrapper
@contextlib.contextmanager
def signals_to_exception(signals=(signal.SIGINT, signal.SIGTERM)):
"""Context manager that converts system signals to exceptions.
This allows for a graceful exit after receiving SIGTERM (e.g. through
`kill` on UNIX systems).
Example:
        >>> with signals_to_exception():
        ...     try:
        ...         pass  # do something
        ...     except SystemExit:
        ...         pass  # graceful exit even upon receiving interrupt signal
"""
def signal_to_exception(sig, frame):
logger.critical("Received interrupt signal {}", sig)
raise SystemExit("Aborted")
old_signals = {}
for s in signals:
# override signals with our handler
old_signals[s] = signal.getsignal(s)
signal.signal(s, signal_to_exception)
try:
yield
finally:
# re-attach old signals
for s in signals:
signal.signal(s, old_signals[s])
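# A minimal sketch of both helpers above (function names are illustrative):
def _signals_demo():
    @do_not_disturb
    def write_results():
        # the decorated body runs to completion even if SIGINT / SIGTERM
        # arrives while it is executing
        logger.info("writing results ...")

    with signals_to_exception():
        try:
            write_results()
        except SystemExit:
            # an interrupt signal was converted to SystemExit by the
            # context manager; clean up and exit gracefully here
            pass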
import contextlib
from collections import defaultdict, namedtuple
from collections.abc import Mapping
from copy import deepcopy
from veros import (
timer,
plugins,
settings as settings_mod,
variables as var_mod,
runtime_settings as rs,
runtime_state as rst,
)
def make_namedtuple(**kwargs):
return namedtuple("KernelOutput", list(kwargs.keys()))(*kwargs.values())
KernelOutput = make_namedtuple
class StrictContainer:
"""A mutable container with fixed fields (optionally typed)."""
__fields__ = ()
__field_types__ = ()
def __init__(self, fields, *args, field_types=None, default=None, **kwargs):
self.__fields__ = fields
if field_types is None:
self.__field_types__ = {}
else:
if not isinstance(field_types, dict) or not set(field_types.keys()) <= set(fields):
raise ValueError("field_types must be a dict with fields as keys")
self.__field_types__ = field_types
for k in fields:
if k in vars(self):
raise ValueError(f"Name collision: {k}")
if k.startswith("_"):
raise ValueError(f"Fields cannot start with _ (got: {k}).")
super().__setattr__(k, default)
def __setattr__(self, key, val):
if not key.startswith("_") and key not in self.__fields__:
raise AttributeError(f"Unknown attribute {key}")
if key in self.__field_types__:
val = self.__field_types__[key](val)
return super().__setattr__(key, val)
def __contains__(self, val):
return val in self.__fields__
def fields(self):
return self.__fields__
def values(self):
return (getattr(self, k) for k in self.__fields__)
def items(self):
return ((k, getattr(self, k)) for k in self.__fields__)
def update(self, other=None, **new_fields):
if other is not None:
if new_fields:
raise ValueError("Either other or new_fields can be given")
if hasattr(other, "_fields"):
# other is namedtuple
new_fields = dict(zip(other._fields, other))
elif isinstance(other, (dict, StrictContainer)):
new_fields = other
else:
raise TypeError(f"Cannot update from {type(other)} type")
for key, val in new_fields.items():
if key not in self.__fields__:
raise AttributeError(f"unknown attribute {key}")
for key, val in new_fields.items():
setattr(self, key, val)
return self
def get(self, key, default=None):
return getattr(self, key, default)
def todict(self):
return {k: getattr(self, k) for k in self.__fields__}
def __repr__(self):
attr_str = []
for key, val in self.items():
# poor-man's check for array-compatible types
if hasattr(val, "shape") and hasattr(val, "dtype"):
val_repr = f"{type(val)} with shape {val.shape}, dtype {val.dtype}"
else:
val_repr = repr(val)
attr_str.append(f" {key} = {val_repr}")
attr_str = ",\n".join(attr_str)
return f"{self.__class__.__qualname__}(\n{attr_str}\n)"
class Lockable:
__locked__ = True
@contextlib.contextmanager
def unlock(self):
lock_state = self.__locked__
try:
self.__locked__ = False
yield
finally:
self.__locked__ = lock_state
def __setattr__(self, key, val):
if not key.startswith("_") and self.__locked__:
clsname = self.__class__.__qualname__
raise RuntimeError(
f"{clsname} is locked to modifications. If you know what you are doing, "
f'you can unlock it via the "{clsname}.unlock()" context manager.'
)
return super().__setattr__(key, val)
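# Sketch of the Lockable contract, combined with StrictContainer as the
# state classes below do; writes are rejected unless wrapped in unlock():
def _lockable_demo():
    class LockedContainer(Lockable, StrictContainer):
        pass

    c = LockedContainer(fields=("a",))
    try:
        c.a = 1  # locked by default
    except RuntimeError:
        pass
    with c.unlock():
        c.a = 1  # modifications are allowed inside the context manager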
class StaticDictProxy(Mapping):
def __init__(self, content, writeback=None):
self._wrapped = content
self._writeback = writeback
def __len__(self):
return self._wrapped.__len__()
def __iter__(self):
return self._wrapped.__iter__()
def __getitem__(self, key):
return self._wrapped.__getitem__(key)
def __setitem__(self, key, val):
if key in self:
raise RuntimeError("Cannot overwrite existing values")
if self._writeback is not None:
self._writeback.__setitem__(key, val)
self._wrapped.__setitem__(key, val)
def __repr__(self):
return f"{self.__class__.__qualname__}({self._wrapped!r})"
class VerosSettings(Lockable, StrictContainer):
def __init__(self, settings_meta):
self.__metadata__ = settings_meta
super().__init__(fields=settings_meta.keys())
default_settings = {k: meta.type(meta.default) for k, meta in settings_meta.items()}
with self.unlock():
self.update(default_settings)
def __setattr__(self, key, val):
if key.startswith("_") or key not in self.__metadata__:
return super().__setattr__(key, val)
meta = self.__metadata__[key]
val = meta.type(val)
return super().__setattr__(key, val)
class VerosVariables(Lockable, StrictContainer):
""" """
def __init__(self, var_meta, dimensions):
self.__metadata__ = var_meta
self.__dimensions__ = dimensions
active_vars = [key for key, val in var_meta.items() if val.active]
super().__init__(fields=active_vars)
with self.unlock():
for key, val in var_meta.items():
if not val.active:
continue
allocate_kwargs = dict(dtype=val.dtype)
if val.initial is not None:
allocate_kwargs.update(fill=val.initial)
setattr(self, key, var_mod.allocate(dimensions, val.dims, **allocate_kwargs))
def __getattr__(self, attr):
orig_getattr = super().__getattribute__
try:
var = orig_getattr("__metadata__")[attr]
except (KeyError, AttributeError):
return orig_getattr(attr)
if not var.active:
raise RuntimeError(
f"Variable {attr} is not active in this configuration. Check your settings and try again."
)
return orig_getattr(attr)
def __setattr__(self, key, val):
if key.startswith("_") or key not in self.__metadata__:
return super().__setattr__(key, val)
var = self.__metadata__[key]
# check whether variable is active
if not var.active:
raise RuntimeError(
f"Variable {key} is not active in this configuration. Check your settings and try again."
)
# validate array type, shape and dtype
if var.dtype is not None:
expected_dtype = var.dtype
else:
expected_dtype = rs.float_type
val = rst.backend_module.asarray(val, dtype=expected_dtype)
expected_shape = self._get_expected_shape(var.dims)
if val.shape != expected_shape:
raise ValueError(f"Got unexpected shape for variable {key} (expected: {expected_shape}, got: {val.shape})")
return super().__setattr__(key, val)
def _get_expected_shape(self, dims):
return var_mod.get_shape(self.__dimensions__, dims)
class DistSafeVariableWrapper(VerosVariables):
def __init__(self, parent_state, local_variables):
# set internal attributes to be identical to given variables object
for attr, val in vars(parent_state).items():
if not attr.startswith("__"):
continue
super().__setattr__(attr, val)
self.__parent_state__ = parent_state
self.__local_variables__ = local_variables
def __getattr__(self, attr):
orig_getattr = super().__getattribute__
if attr in orig_getattr("__metadata__") and attr not in orig_getattr("__local_variables__"):
raise RuntimeError(
f"Cannot access variable {attr} because it was not collected. "
"Consider adding it to the local_variables argument of @veros_routine."
)
return orig_getattr(attr)
def __setattr__(self, attr, val):
if attr.startswith("_"):
return super().__setattr__(attr, val)
if attr in self.__metadata__ and attr not in self.__local_variables__:
raise RuntimeError(
f"Cannot access variable {attr} because it was not collected. "
"Consider adding it to the local_variables argument of @veros_routine."
)
return super().__setattr__(attr, val)
def _gather_variables(self):
from veros.distributed import gather
var_meta = self.__metadata__
for var in self.__local_variables__:
if var not in var_meta:
raise ValueError(f"encountered unknown variable {var} in local variables")
if not var_meta[var].active:
continue
gathered_var = gather(getattr(self.__parent_state__, var), self.__dimensions__, self.__metadata__[var].dims)
setattr(self, var, gathered_var)
def _scatter_variables(self):
from veros.distributed import scatter, barrier
barrier()
var_meta = self.__metadata__
for var in self.__local_variables__:
if var not in var_meta:
raise ValueError(f"encountered unknown variable {var} in local variables")
if not var_meta[var].active:
continue
scattered_var = scatter(getattr(self, var), self.__dimensions__, self.__metadata__[var].dims)
setattr(self.__parent_state__, var, scattered_var)
def _get_expected_shape(self, dims):
return var_mod.get_shape(self.__dimensions__, dims, local=rst.proc_rank != 0)
def __repr__(self):
return f"{self.__class__.__qualname__}(parent_state={self.__parent_state__}, local_variables={self.__local_variables__})"
class VerosState:
"""Holds all settings and model state for a given Veros run."""
def __init__(self, var_meta, setting_meta, dimensions, diagnostics=None, plugin_interfaces=None):
self._var_meta = var_meta
self._variables = None
self._settings = VerosSettings(setting_meta)
self._dimensions = dimensions
if diagnostics is not None:
self._diagnostics = diagnostics
else:
self._diagnostics = {}
if plugin_interfaces is not None:
self._plugin_interfaces = plugin_interfaces
else:
self._plugin_interfaces = ()
timer_factory = timer.Timer
self.timers = defaultdict(timer_factory)
self.profile_timers = defaultdict(timer_factory)
def __repr__(self):
from textwrap import indent
attr_str = []
for attr in ("settings", "dimensions", "variables", "diagnostics", "plugin_interfaces"):
# indent all lines of attr repr except the first
attr_val = indent(repr(getattr(self, f"_{attr}")), " " * 4)[4:]
attr_str.append(f" {attr} = {attr_val}")
attr_str = ",\n".join(attr_str)
return f"{self.__class__.__qualname__}(\n{attr_str}\n)"
def initialize_variables(self):
if self._variables is not None:
raise RuntimeError("Variables are already initialized.")
self._var_meta = var_mod.manifest_metadata(self._var_meta, self._settings)
self._variables = VerosVariables(self._var_meta, self._manifest_dimensions())
@property
def var_meta(self):
return self._var_meta
@property
def variables(self):
if self._variables is None:
raise RuntimeError("Variables have not been initialized yet.")
return self._variables
@property
def settings(self):
return self._settings
def _manifest_dimensions(self):
concrete_dimensions = {}
for dim_name, dim_target in self._dimensions.items():
if isinstance(dim_target, str):
dim_size = getattr(self._settings, dim_target)
else:
dim_size = dim_target
concrete_dimensions[dim_name] = int(dim_size)
return concrete_dimensions
@property
def dimensions(self):
concrete_dimensions = self._manifest_dimensions()
return StaticDictProxy(concrete_dimensions, self._dimensions)
@property
def diagnostics(self):
return self._diagnostics
@property
def plugin_interfaces(self):
return self._plugin_interfaces
def to_xarray(self):
import xarray as xr
vs = self.variables
coords = {}
data_vars = {}
for var_name, var_meta in self.var_meta.items():
if not var_meta.active:
continue
data = var_mod.remove_ghosts(vs.get(var_name), var_meta.dims)
data_vars[var_name] = xr.DataArray(
data,
dims=var_meta.dims,
name=var_name,
attrs=dict(
long_description=var_meta.long_description,
units=var_meta.units,
scale=var_meta.scale,
),
)
if var_meta.dims is None:
continue
for dim in var_meta.dims:
if dim not in coords:
coords[dim] = range(var_mod.get_shape(self.dimensions, (dim,), include_ghosts=False)[0])
data_vars = {k: v for k, v in data_vars.items() if k not in coords}
attrs = dict(self.settings.items())
return xr.Dataset(data_vars, coords=coords, attrs=attrs)
def get_default_state(plugin_interfaces=()):
if isinstance(plugin_interfaces, plugins.VerosPlugin):
plugin_interfaces = [plugin_interfaces]
for plugin in plugin_interfaces:
if not isinstance(plugin, plugins.VerosPlugin):
raise TypeError(f"Got unexpected type {type(plugin)}")
settings = deepcopy(settings_mod.SETTINGS)
dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR)
var_meta = deepcopy(var_mod.VARIABLES)
for plugin in plugin_interfaces:
settings.update(plugin.settings)
var_meta.update(plugin.variables)
dimensions.update(plugin.dimensions)
return VerosState(var_meta, settings, dimensions, plugin_interfaces=plugin_interfaces)
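# Typical construction flow (sketch): build a default state, configure the
# (locked) settings, then allocate all active variables.
def _default_state_demo():
    state = get_default_state()
    with state.settings.unlock():
        state.settings.nx = 30
        state.settings.ny = 30
        state.settings.nz = 10
    state.initialize_variables()
    return state  # state.to_xarray() now yields an annotated dataset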
def veros_state_pytree_flatten(state):
aux_data = tuple((k, v) for k, v in vars(state).items() if k != "_variables")
# ensure that functions are re-traced when settings change
with state.settings.unlock():
pseudo_hash = hash(tuple(state.settings.items()))
return ([state.variables], (aux_data, pseudo_hash))
def veros_state_pytree_unflatten(aux_data, leaves):
assert len(leaves) == 1
variables = leaves[0]
# by-pass __init__ and set attributes manually
state = VerosState.__new__(VerosState)
state._variables = variables
state_attrs, _ = aux_data
for attr, val in state_attrs:
setattr(state, attr, val)
return state
def veros_variables_pytree_flatten(variables):
aux_attrs = (
"__dimensions__",
"__metadata__",
"__fields__",
"__locked__",
)
leaves = list(variables.values())
aux_data = (tuple(variables.fields()), tuple((attr, getattr(variables, attr)) for attr in aux_attrs))
return (leaves, aux_data)
def veros_variables_pytree_unflatten(aux_data, leaves):
keys, aux_attrs = aux_data
# by-pass __init__ and set attributes manually
variables = VerosVariables.__new__(VerosVariables)
for key, val in aux_attrs:
setattr(variables, key, val)
with variables.unlock():
for key, val in zip(keys, leaves):
setattr(variables, key, val)
return variables
def dist_safe_wrapper_pytree_flatten(variables):
aux_attrs = (
"__dimensions__",
"__metadata__",
"__fields__",
"__locked__",
"__local_variables__",
"__parent_state__",
)
with variables.unlock():
leaves = [getattr(variables, attr) for attr in variables.__local_variables__]
aux_data = (tuple(variables.__local_variables__), tuple((attr, getattr(variables, attr)) for attr in aux_attrs))
return (leaves, aux_data)
def dist_safe_wrapper_pytree_unflatten(aux_data, leaves):
keys, aux_attrs = aux_data
# by-pass __init__ and set attributes manually
variables = DistSafeVariableWrapper.__new__(DistSafeVariableWrapper)
for key, val in aux_attrs:
setattr(variables, key, val)
with variables.unlock():
for key, val in zip(keys, leaves):
setattr(variables, key, val)
return variables
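# The flatten/unflatten pairs above are meant for JAX pytree registration,
# so whole state objects can pass through jit-compiled functions (sketch,
# assuming the JAX backend is in use):
def _register_pytree_nodes():
    import jax

    jax.tree_util.register_pytree_node(VerosState, veros_state_pytree_flatten, veros_state_pytree_unflatten)
    jax.tree_util.register_pytree_node(
        VerosVariables, veros_variables_pytree_flatten, veros_variables_pytree_unflatten
    )
    jax.tree_util.register_pytree_node(
        DistSafeVariableWrapper, dist_safe_wrapper_pytree_flatten, dist_safe_wrapper_pytree_unflatten
    )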
def resize_dimension(state, dimension, new_size):
"""Resize a dimension of an existing VerosState object.
    All variables that use this dimension are re-allocated and reset to 0.
"""
state._dimensions[dimension] = new_size
state.variables.__dimensions__[dimension] = new_size
with state.variables.unlock():
for var in state.variables.fields():
var_meta = state.variables.__metadata__[var]
var_dims = var_meta.dims
if var_dims is None or dimension not in var_dims:
continue
setattr(state.variables, var, var_mod.allocate(state.dimensions, var_meta.dims, dtype=var_meta.dtype))
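# Usage sketch: the "isle" dimension starts out with size 0 (see
# DIM_TO_SHAPE_VAR below) and is grown once the number of islands is known.
def _resize_isle_demo(state, n_islands):
    resize_dimension(state, "isle", n_islands)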
YEAR_LENGTH = 360.0
X_TO_SECONDS = {
"seconds": 1.0,
"minutes": 60.0,
"hours": 60.0 * 60.0,
"days": 24.0 * 60.0 * 60.0,
"years": YEAR_LENGTH * 24.0 * 60.0 * 60.0,
}
SECONDS_TO_X = {key: 1.0 / val for key, val in X_TO_SECONDS.items()}
def convert_time(time_value, in_unit, out_unit):
return time_value * X_TO_SECONDS[in_unit] * SECONDS_TO_X[out_unit]
def format_time(time_value, in_unit="seconds"):
all_units = X_TO_SECONDS.keys()
val_in_all_units = {u: convert_time(time_value, in_unit, u) for u in all_units}
valid_units = {u: v for u, v in val_in_all_units.items() if v >= 1.0}
if valid_units:
best_unit = min(valid_units, key=valid_units.get)
else:
best_unit = "seconds"
return val_in_all_units[best_unit], best_unit
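# Quick sanity examples (note that a model year is 360 days here):
def _time_conversion_demo():
    assert convert_time(2.0, "days", "hours") == 48.0
    assert format_time(90.0) == (1.5, "minutes")
    assert format_time(0.5) == (0.5, "seconds")  # sub-second values stay in seconds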
import timeit
import threading
timer_context = threading.local()
timer_context.active = True
class Timer:
def __init__(self):
self.total_time = 0
self.last_time = 0
def __enter__(self):
self.start_time = timeit.default_timer()
def __exit__(self, *args, **kwargs):
self.last_time = timeit.default_timer() - self.start_time
if timer_context.active:
self.total_time += self.last_time
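# Usage sketch: a Timer accumulates over repeated `with` blocks; setting
# `timer_context.active = False` suspends accumulation of total_time.
def _timer_demo():
    t = Timer()
    for _ in range(3):
        with t:
            sum(range(10000))
    assert t.total_time >= t.last_time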
from veros.tools.assets import get_assets # noqa: F401
from veros.tools.setup import ( # noqa: F401
interpolate,
fill_holes,
get_periodic_interval,
make_cyclic,
get_coastline_distance,
get_uniform_grid_steps,
get_stretched_grid_steps,
get_vinokur_grid_steps,
)
import os
import json
import shutil
import hashlib
import urllib.parse as urlparse
import requests
from veros.tools.filelock import FileLock
from veros import logger, runtime_state
ASSET_DIRECTORY = os.environ.get("VEROS_ASSET_DIR") or os.path.join(os.path.expanduser("~"), ".veros", "assets")
class AssetError(Exception):
pass
class AssetStore:
def __init__(self, asset_dir, asset_config, skip_md5=False):
self._asset_dir = asset_dir
self._asset_config = asset_config
self._stored_assets = {}
self._skip_md5 = skip_md5
def _get_asset(self, key):
url = self._asset_config[key]["url"]
md5 = self._asset_config[key].get("md5")
skip_md5 = self._skip_md5
target_filename = os.path.basename(urlparse.urlparse(url).path)
target_path = os.path.join(self._asset_dir, target_filename)
target_lock = target_path + ".lock"
with FileLock(target_lock):
if not os.path.isfile(target_path):
logger.info("Downloading asset {} ...", target_filename)
_download_file(url, target_path)
# always validate freshly downloaded files
skip_md5 = False
check_md5 = not skip_md5 and md5 is not None and runtime_state.proc_rank == 0
if check_md5:
if _filehash(target_path) != md5:
raise AssetError(f"Mismatching MD5 checksum on asset {target_filename}")
return target_path
def keys(self):
return self._asset_config.keys()
def __contains__(self, key):
return key in self.keys()
def __getitem__(self, key):
if key not in self:
raise KeyError(f"unknown asset {key}")
if key not in self._stored_assets:
self._stored_assets[key] = self._get_asset(key)
return self._stored_assets[key]
def __repr__(self):
out = f"{self.__class__.__name__}(asset_dir={self._asset_dir}, asset_config={self._asset_config})"
return out
def get_assets(asset_id, asset_file, skip_md5=False):
"""Handles automatic download and verification of external assets (such as forcing files).
By default, assets are stored in ``$HOME/.veros/assets`` (can be overwritten by setting
``VEROS_ASSET_DIR`` environment variable to the desired location).
Arguments:
asset_id (str): Identifier of the collection of assets. Should be unique for each setup.
        asset_file (str): JSON file containing URLs and (optionally) MD5 checksums of each asset.
skip_md5 (bool): Whether to skip MD5 checksum validation (useful for huge asset files)
Returns:
A ``dict``-like mapping of each asset to file name on disk. Assets are downloaded lazily.
Example:
>>> assets = get_assets('mysetup', 'assets.json')
        >>> assets['forcing']
        '/home/user/.veros/assets/mysetup/mysetup_forcing.h5'
        >>> assets['initial_conditions']
        '/home/user/.veros/assets/mysetup/initial.h5'
In this case, ``assets.json`` contains::
{
"forcing": {
"url": "https://mywebsite.com/veros_assets/mysetup_forcing.h5",
"md5": "ef3be0a58782771c8ee5a6d0206b87f6"
},
"initial_conditions": {
"url": "https://mywebsite.com/veros_assets/initial.h5",
"md5": "d1b4e0e199d7a5883cf7c88d3d6bcb28"
}
}
"""
with open(asset_file, "r") as f:
assets = json.load(f)
asset_dir = os.path.join(ASSET_DIRECTORY, asset_id)
if not os.path.isdir(asset_dir):
        try:  # possible race condition when several processes start at once
            os.makedirs(asset_dir)
        except OSError:
            if not os.path.isdir(asset_dir):
                raise
return AssetStore(asset_dir, assets, skip_md5)
def _download_file(url, target_path, timeout=10):
"""Download a file and save it to a folder"""
tmpfile = f"{target_path}.incomplete"
with requests.get(url, stream=True, timeout=timeout) as response:
response.raise_for_status()
response.raw.decode_content = True
try:
with open(tmpfile, "wb") as dst:
shutil.copyfileobj(response.raw, dst)
except: # noqa: E722
os.remove(tmpfile)
raise
shutil.move(tmpfile, target_path)
return target_path
def _filehash(path):
hash_md5 = hashlib.md5()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
# This is free and unencumbered software released into the public domain.
#
# Anyone is free to copy, modify, publish, use, compile, sell, or
# distribute this software, either in source code form or as a compiled
# binary, for any purpose, commercial or non-commercial, and by any
# means.
#
# In jurisdictions that recognize copyright laws, the author or authors
# of this software dedicate any and all copyright interest in the
# software to the public domain. We make this dedication for the benefit
# of the public at large and to the detriment of our heirs and
# successors. We intend this dedication to be an overt act of
# relinquishment in perpetuity of all present and future rights to this
# software under copyright law.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
#
# For more information, please refer to <http://unlicense.org>
"""
A platform independent file lock that supports the with-statement.
"""
# Modules
# ------------------------------------------------
import logging
import os
import threading
import time
try:
import warnings
except ImportError:
warnings = None
try:
import msvcrt
except ImportError:
msvcrt = None
try:
import fcntl
except ImportError:
fcntl = None
# Backward compatibility
# ------------------------------------------------
try:
TimeoutError
except NameError:
TimeoutError = OSError
# Data
# ------------------------------------------------
__all__ = ["Timeout", "BaseFileLock", "WindowsFileLock", "UnixFileLock", "SoftFileLock", "FileLock"]
__version__ = "3.0.12"
_logger = None
def logger():
"""Returns the logger instance used in this module."""
global _logger
_logger = _logger or logging.getLogger(__name__)
return _logger
# Exceptions
# ------------------------------------------------
class Timeout(TimeoutError):
"""
Raised when the lock could not be acquired in *timeout*
seconds.
"""
def __init__(self, lock_file):
""" """
#: The path of the file lock.
self.lock_file = lock_file
return None
def __str__(self):
temp = f"The file lock '{self.lock_file}' could not be acquired."
return temp
# Classes
# ------------------------------------------------
# This is a helper class which is returned by :meth:`BaseFileLock.acquire`
# and wraps the lock to make sure __enter__ is not called twice when entering
# the with statement.
# If we simply returned *self*, the lock would be acquired again
# in the *__enter__* method of the BaseFileLock, but not released again
# automatically.
#
# :seealso: issue #37 (memory leak)
class _Acquire_ReturnProxy(object):
def __init__(self, lock):
self.lock = lock
return None
def __enter__(self):
return self.lock
def __exit__(self, exc_type, exc_value, traceback):
self.lock.release()
return None
class BaseFileLock(object):
"""
Implements the base class of a file lock.
"""
def __init__(self, lock_file, timeout=-1):
""" """
# The path to the lock file.
self._lock_file = lock_file
# The file descriptor for the *_lock_file* as it is returned by the
# os.open() function.
        # This is None unless the object currently holds the lock.
self._lock_file_fd = None
# The default timeout value.
self.timeout = timeout
# We use this lock primarily for the lock counter.
self._thread_lock = threading.Lock()
# The lock counter is used for implementing the nested locking
# mechanism. Whenever the lock is acquired, the counter is increased and
# the lock is only released, when this value is 0 again.
self._lock_counter = 0
return None
@property
def lock_file(self):
"""
The path to the lock file.
"""
return self._lock_file
@property
def timeout(self):
"""
        You can set a default timeout for the file lock. It is used as a
        fallback value in the acquire method if no timeout value (*None*) is
        given.
        If you want to disable the timeout, set it to a negative value.
        A timeout of 0 means that there is exactly one attempt to acquire the
        file lock.
.. versionadded:: 2.0.0
"""
return self._timeout
@timeout.setter
def timeout(self, value):
""" """
self._timeout = float(value)
return None
# Platform dependent locking
# --------------------------------------------
def _acquire(self):
"""
Platform dependent. If the file lock could be
acquired, self._lock_file_fd holds the file descriptor
of the lock file.
"""
raise NotImplementedError()
def _release(self):
"""
Releases the lock and sets self._lock_file_fd to None.
"""
raise NotImplementedError()
# Platform independent methods
# --------------------------------------------
@property
def is_locked(self):
"""
True, if the object holds the file lock.
.. versionchanged:: 2.0.0
This was previously a method and is now a property.
"""
return self._lock_file_fd is not None
def acquire(self, timeout=None, poll_intervall=0.05):
"""
Acquires the file lock or fails with a :exc:`Timeout` error.
.. code-block:: python
# You can use this method in the context manager (recommended)
with lock.acquire():
pass
# Or use an equivalent try-finally construct:
lock.acquire()
try:
pass
finally:
lock.release()
        :arg float timeout:
            The maximum time waited for the file lock.
            If ``timeout < 0``, there is no timeout and this method will
            block until the lock can be acquired.
            If ``timeout`` is None, the default :attr:`~timeout` is used.
        :arg float poll_intervall:
            The lock is polled every *poll_intervall* seconds until it can
            be acquired.
        :raises Timeout:
            if the lock could not be acquired in *timeout* seconds.
        .. versionchanged:: 2.0.0
            This method now returns a *proxy* object instead of *self*,
            so that it can be used in a with statement without side effects.
"""
# Use the default timeout, if no timeout is provided.
if timeout is None:
timeout = self.timeout
# Increment the number right at the beginning.
# We can still undo it, if something fails.
with self._thread_lock:
self._lock_counter += 1
lock_id = id(self)
lock_filename = self._lock_file
start_time = time.time()
try:
while True:
with self._thread_lock:
if not self.is_locked:
logger().debug("Attempting to acquire lock %s on %s", lock_id, lock_filename)
self._acquire()
if self.is_locked:
logger().info("Lock %s acquired on %s", lock_id, lock_filename)
break
elif timeout >= 0 and time.time() - start_time > timeout:
logger().debug("Timeout on acquiring lock %s on %s", lock_id, lock_filename)
raise Timeout(self._lock_file)
else:
logger().debug(
"Lock %s not acquired on %s, waiting %s seconds ...", lock_id, lock_filename, poll_intervall
)
time.sleep(poll_intervall)
except:
# Something did go wrong, so decrement the counter.
with self._thread_lock:
self._lock_counter = max(0, self._lock_counter - 1)
raise
return _Acquire_ReturnProxy(lock=self)
def release(self, force=False):
"""
        Releases the file lock.
        Note that the lock is only fully released once the lock counter
        reaches 0.
        Also note that the lock file itself is not automatically deleted.
        :arg bool force:
            If true, the lock counter is ignored and the lock is released
            unconditionally.
"""
with self._thread_lock:
if self.is_locked:
self._lock_counter -= 1
if self._lock_counter == 0 or force:
lock_id = id(self)
lock_filename = self._lock_file
logger().debug("Attempting to release lock %s on %s", lock_id, lock_filename)
self._release()
self._lock_counter = 0
logger().info("Lock %s released on %s", lock_id, lock_filename)
return None
def __enter__(self):
self.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.release()
return None
def __del__(self):
self.release(force=True)
return None
# Windows locking mechanism
# ~~~~~~~~~~~~~~~~~~~~~~~~~
class WindowsFileLock(BaseFileLock):
"""
Uses the :func:`msvcrt.locking` function to hard lock the lock file on
windows systems.
"""
def _acquire(self):
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
try:
fd = os.open(self._lock_file, open_mode)
except OSError:
pass
else:
try:
msvcrt.locking(fd, msvcrt.LK_NBLCK, 1)
except (IOError, OSError):
os.close(fd)
else:
self._lock_file_fd = fd
return None
def _release(self):
fd = self._lock_file_fd
self._lock_file_fd = None
msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)
os.close(fd)
try:
os.remove(self._lock_file)
# Probably another instance of the application
# that acquired the file lock.
except OSError:
pass
return None
# Unix locking mechanism
# ~~~~~~~~~~~~~~~~~~~~~~
class UnixFileLock(BaseFileLock):
"""
Uses the :func:`fcntl.flock` to hard lock the lock file on unix systems.
"""
def _acquire(self):
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
fd = os.open(self._lock_file, open_mode)
try:
fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
except (IOError, OSError):
os.close(fd)
else:
self._lock_file_fd = fd
return None
def _release(self):
# Do not remove the lockfile:
#
# https://github.com/benediktschmitt/py-filelock/issues/31
# https://stackoverflow.com/questions/17708885/flock-removing-locked-file-without-race-condition
fd = self._lock_file_fd
self._lock_file_fd = None
fcntl.flock(fd, fcntl.LOCK_UN)
os.close(fd)
return None
# Soft lock
# ~~~~~~~~~
class SoftFileLock(BaseFileLock):
"""
Simply watches the existence of the lock file.
"""
def _acquire(self):
open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
try:
fd = os.open(self._lock_file, open_mode)
except (IOError, OSError):
pass
else:
self._lock_file_fd = fd
return None
def _release(self):
os.close(self._lock_file_fd)
self._lock_file_fd = None
try:
os.remove(self._lock_file)
# The file is already deleted and that's what we want.
except OSError:
pass
return None
# Platform filelock
# ~~~~~~~~~~~~~~~~~
#: Alias for the lock, which should be used for the current platform. On
#: Windows, this is an alias for :class:`WindowsFileLock`, on Unix for
#: :class:`UnixFileLock` and otherwise for :class:`SoftFileLock`.
FileLock = None
if msvcrt:
FileLock = WindowsFileLock
elif fcntl:
FileLock = UnixFileLock
else:
FileLock = SoftFileLock
if warnings is not None:
warnings.warn("only soft file lock is available")
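# Minimal usage sketch (this mirrors how the asset store above serializes
# downloads via `with FileLock(target_lock):`):
def _filelock_demo(path="/tmp/veros_demo.lock"):
    lock = FileLock(path, timeout=10)
    with lock:
        pass  # critical section: only one process may be here at a time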
from veros.core.operators import numpy as npx
import numpy as onp
import scipy.interpolate
import scipy.spatial
def interpolate(coords, var, interp_coords, missing_value=None, fill=True, kind="linear"):
"""Interpolate globally defined data to a different (regular) grid.
Arguments:
coords: Tuple of coordinate arrays for each dimension.
var (:obj:`ndarray` of dim (nx1, ..., nxd)): Variable data to interpolate.
interp_coords: Tuple of coordinate arrays to interpolate to.
        missing_value (optional): Value denoting cells of missing data in ``var``.
            Is replaced by `NaN` before interpolating. Defaults to `None`, which means
            no replacement takes place.
fill (bool, optional): Whether `NaN` values should be replaced by the nearest
finite value after interpolating. Defaults to ``True``.
kind (str, optional): Order of interpolation. Supported are `nearest` and
`linear` (default).
Returns:
:obj:`ndarray` containing the interpolated values on the grid spanned by
``interp_coords``.
"""
if len(coords) != len(interp_coords) or len(coords) != var.ndim:
raise ValueError("Dimensions of coordinates and values do not match")
if missing_value is not None:
invalid_mask = npx.isclose(var, missing_value)
var = npx.where(invalid_mask, npx.nan, var)
if var.ndim > 1 and coords[0].ndim == 1:
interp_grid = npx.rollaxis(npx.array(npx.meshgrid(*interp_coords, indexing="ij")), 0, len(interp_coords) + 1)
else:
interp_grid = interp_coords
def as_floatarray(x):
return onp.array(x, dtype="float64")
coords = tuple(as_floatarray(c) for c in coords)
var = scipy.interpolate.interpn(
coords, as_floatarray(var), as_floatarray(interp_grid), bounds_error=False, fill_value=onp.nan, method=kind
)
var = npx.asarray(var)
if fill:
var = fill_holes(var)
return var
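# Sketch: regrid a synthetic 2-D field from a coarse to a finer regular grid.
def _interpolate_demo():
    coarse = (onp.linspace(0.0, 1.0, 10), onp.linspace(0.0, 1.0, 20))
    field = onp.random.rand(10, 20)
    fine = (onp.linspace(0.0, 1.0, 50), onp.linspace(0.0, 1.0, 100))
    out = interpolate(coarse, field, fine)
    assert out.shape == (50, 100)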
def fill_holes(data):
"""A simple inpainting function that replaces NaN values in `data` with the
nearest finite value.
"""
data = onp.array(data)
dim = data.ndim
flag = ~onp.isnan(data)
slcs = [slice(None)] * dim
while onp.any(~flag):
for i in range(dim):
slcs1 = slcs[:]
slcs2 = slcs[:]
slcs1[i] = slice(0, -1)
slcs2[i] = slice(1, None)
slcs1 = tuple(slcs1)
slcs2 = tuple(slcs2)
# replace from the right
repmask = onp.logical_and(~flag[slcs1], flag[slcs2])
data[slcs1][repmask] = data[slcs2][repmask]
flag[slcs1][repmask] = True
# replace from the left
repmask = onp.logical_and(~flag[slcs2], flag[slcs1])
data[slcs2][repmask] = data[slcs1][repmask]
flag[slcs2][repmask] = True
return npx.asarray(data)
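# Sketch: NaN cells are flooded with the value of their nearest defined neighbor.
def _fill_holes_demo():
    data = onp.array([[1.0, onp.nan], [onp.nan, 4.0]])
    filled = fill_holes(data)
    assert not onp.any(onp.isnan(filled))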
def get_periodic_interval(current_time, cycle_length, rec_spacing, n_rec):
"""Used for linear interpolation between periodic time intervals.
One common application is the interpolation of external forcings that are defined
at discrete times (e.g. one value per month of a standard year) to the current
time step.
Arguments:
current_time (float): Time to interpolate to.
cycle_length (float): Total length of one periodic cycle.
rec_spacing (float): Time spacing between each data record.
n_rec (int): Total number of records available.
Returns:
:obj:`tuple` containing (n1, f1), (n2, f2): Indices and weights for the interpolated
record array.
Example:
The following interpolates a record array ``data`` containing 12 monthly values
to the current time step:
>>> year_in_seconds = 60. * 60. * 24. * 365.
        >>> current_time = 60. * 60. * 24. * 45. # mid-February
>>> print(data.shape)
(360, 180, 12)
>>> (n1, f1), (n2, f2) = get_periodic_interval(current_time, year_in_seconds, year_in_seconds / 12, 12)
>>> data_at_current_time = f1 * data[..., n1] + f2 * data[..., n2]
"""
current_time = current_time % cycle_length
# using npx.array works with both NumPy and JAX
t_idx_1 = npx.array(current_time // rec_spacing, dtype="int")
t_idx_2 = npx.array((1 + t_idx_1) % n_rec, dtype="int")
weight_2 = (current_time - rec_spacing * t_idx_1) / rec_spacing
weight_1 = 1.0 - weight_2
return (t_idx_1, weight_1), (t_idx_2, weight_2)
def make_cyclic(longitude, array=None, wrap=360.0):
"""Create a cyclic version of a longitude array and (optionally) another array.
Arguments:
longitude (ndarray): Longitude array of shape (nlon, ...).
array (ndarray): Another array that is to be made cyclic of shape (nlon, ...).
wrap (float): Wrapping value, defaults to 360 (degrees).
Returns:
Tuple containing (cyclic_longitudes, cyclic_array) if `array` is given, otherwise
just the ndarray cyclic_longitudes of shape (2 * nlon, ...).
"""
lonsize = longitude.shape[0]
cyclic_longitudes = npx.hstack(
(longitude[lonsize // 2 :, ...] - wrap, longitude, longitude[: lonsize // 2, ...] + wrap)
)
if array is None:
return cyclic_longitudes
cyclic_array = npx.hstack((array[lonsize // 2 :, ...], array, array[: lonsize // 2, ...]))
return cyclic_longitudes, cyclic_array
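# Sketch: a 10-degree-spaced longitude axis is padded with wrapped copies on
# both sides, doubling its length.
def _make_cyclic_demo():
    lon = npx.linspace(0.0, 350.0, 36)
    cyclic_lon = make_cyclic(lon)
    assert cyclic_lon.shape[0] == 2 * lon.shape[0]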
def get_coastline_distance(coords, coast_mask, spherical=False, radius=None, num_candidates=None):
"""Calculate the (approximate) distance of each water cell from the nearest coastline.
Arguments:
coords (tuple of ndarrays): Tuple containing x and y (longitude and latitude)
coordinate arrays of shape (nx, ny).
coast_mask (ndarray): Boolean mask indicating whether a cell is a land cell
(must be same shape as coordinate arrays).
spherical (bool): Use spherical instead of Cartesian coordinates.
When this is `True`, cyclical boundary conditions are used, and the
resulting distances are only approximate. Cells are pre-sorted by
Euclidean lon-lat distance, and great circle distances are calculated for
the first `num_candidates` elements. Defaults to `False`.
radius (float): Radius of spherical coordinate system. Must be given when
`spherical` is `True`.
        num_candidates (int): Number of candidates for which great circle distances
            are calculated for each water cell. The higher this value, the more accurate the returned
            distances become when `spherical` is `True`. Defaults to the square root
            of the number of coastal cells.
Returns:
:obj:`ndarray` of shape (nx, ny) indicating the distance to the nearest land
cell (0 if cell is land).
Example:
The following returns coastal distances of all T cells for a spherical Veros setup.
>>> coords = npx.meshgrid(vs.xt[2:-2], vs.yt[2:-2], indexing='ij')
>>> dist = tools.get_coastline_distance(coords, vs.kbot > 0, spherical=True, radius=settings.radius)
"""
if not len(coords) == 2:
raise ValueError("coords must be lon-lat tuple")
if not all(c.shape == coast_mask.shape for c in coords):
raise ValueError("coordinates must have same shape as coastal mask")
if spherical and not radius:
raise ValueError("radius must be given for spherical coordinates")
watercoords = onp.array([c[~coast_mask] for c in coords]).T
if spherical:
coastcoords = onp.array(make_cyclic(coords[0][coast_mask], coords[1][coast_mask])).T
else:
coastcoords = onp.array((coords[0][coast_mask], coords[1][coast_mask])).T
coast_kdtree = scipy.spatial.cKDTree(coastcoords)
distance = onp.zeros(coords[0].shape)
if spherical:
def spherical_distance(coords1, coords2):
"""Calculate great circle distance from latitude and longitude"""
coords1 *= onp.pi / 180.0
coords2 *= onp.pi / 180.0
lon1, lon2, lat1, lat2 = coords1[..., 0], coords2[..., 0], coords1[..., 1], coords2[..., 1]
return radius * onp.arccos(
onp.sin(lat1) * onp.sin(lat2) + onp.cos(lat1) * onp.cos(lat2) * onp.cos(lon1 - lon2)
)
if not num_candidates:
num_candidates = int(onp.sqrt(onp.count_nonzero(~coast_mask)))
i_nearest = coast_kdtree.query(watercoords, k=num_candidates)[1]
approx_nearest = coastcoords[i_nearest]
distance[~coast_mask] = onp.min(spherical_distance(approx_nearest, watercoords[..., onp.newaxis, :]), axis=-1)
else:
distance[~coast_mask] = coast_kdtree.query(watercoords)[0]
return npx.asarray(distance)
def get_uniform_grid_steps(total_length, stepsize):
"""Get uniform grid step sizes in an interval.
Arguments:
total_length (float): total length of the resulting grid
stepsize (float): grid step size
Returns:
:obj:`ndarray` of grid steps
Example:
        >>> uniform_steps = get_uniform_grid_steps(6., 0.25)
>>> uniform_steps
[ 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25,
0.25, 0.25, 0.25, 0.25, 0.25, 0.25 ]
"""
if total_length % stepsize:
raise ValueError("total length must be an integer multiple of stepsize")
return stepsize * npx.ones(int(total_length / stepsize))
def get_stretched_grid_steps(
n_cells, total_length, minimum_stepsize, stretching_factor=2.5, two_sided_grid=False, refine_towards="upper"
):
"""Computes stretched grid steps for regional and global domains with either
one or two-sided stretching using a hyperbolic tangent stretching function.
Arguments:
n_cells (int): Number of grid points.
total_length (float): Length of the grid interval to be covered (sum of the
resulting grid steps).
minimum_stepsize (float): Grid step size on the lower end of the interval.
stretching_factor (float, optional): Coefficient of the `tanh` stretching
function. The higher this value, the more abrupt the step sizes change.
two_sided_grid (bool, optional): If set to `True`, the resulting grid will be symmetrical
around the center. Defaults to `False`.
refine_towards ('upper' or 'lower', optional): The side of the interval that is to be refined.
Defaults to 'upper'.
Returns:
:obj:`ndarray` of shape `(n_cells)` containing grid steps.
Examples:
>>> dyt = get_stretched_grid_steps(14, 180, 5)
>>> dyt
[ 5.10517337 5.22522948 5.47813251 5.99673813 7.00386752
8.76808565 11.36450896 14.34977676 16.94620006 18.71041819
19.71754758 20.2361532 20.48905624 20.60911234]
>>> dyt.sum()
180.0
>>> dyt = get_stretched_grid_steps(14, 180, 5, stretching_factor=4.)
>>> dyt
[ 5.00526979 5.01802837 5.06155549 5.20877528 5.69251688
7.14225176 10.51307232 15.20121339 18.57203395 20.02176884
20.50551044 20.65273022 20.69625734 20.70901593]
>>> dyt.sum()
180.0
"""
if refine_towards not in ("upper", "lower"):
raise ValueError('refine_towards must be "upper" or "lower"')
if two_sided_grid:
        if n_cells % 2:
            raise ValueError(f"number of grid points must be an even integer (given: {n_cells})")
        n_cells = n_cells // 2
stretching_function = npx.tanh(stretching_factor * npx.linspace(-1, 1, n_cells))
if refine_towards == "lower":
stretching_function = stretching_function[::-1]
if two_sided_grid:
stretching_function = npx.concatenate((stretching_function[::-1], stretching_function))
def normalize_sum(var, sum_value, minimum_value=0.0):
if abs(var.sum()) < 1e-5:
var += 1
var *= (sum_value - len(var) * minimum_value) / var.sum()
return var + minimum_value
stretching_function = normalize_sum(stretching_function, total_length, minimum_stepsize)
assert abs(1 - npx.sum(stretching_function) / total_length) < 1e-5, "precision error"
return stretching_function
def get_vinokur_grid_steps(
n_cells, total_length, lower_stepsize, upper_stepsize=None, two_sided_grid=False, refine_towards="upper"
):
"""Computes stretched grid steps for regional and global domains with either
one or two-sided stretching using Vinokur stretching.
This stretching function minimizes discretization errors on finite difference
grids.
Arguments:
n_cells (int): Number of grid points.
total_length (float): Length of the grid interval to be covered (sum of the
resulting grid steps).
lower_stepsize (float): Grid step size on the lower end of the interval.
upper_stepsize (float or ``None``, optional): Grid step size on the upper end of the interval.
If not given, the one-sided version of the algorithm is used (that enforces zero curvature
on the upper end).
two_sided_grid (bool, optional): If set to `True`, the resulting grid will be symmetrical
around the center. Defaults to `False`.
refine_towards ('upper' or 'lower', optional): The side of the interval that is to be refined.
Defaults to 'upper'.
Returns:
:obj:`ndarray` of shape `(n_cells)` containing grid steps.
Reference:
Vinokur, Marcel, On One-Dimensional Stretching Functions for Finite-Difference Calculations,
Journal of Computational Physics. 50, 215, 1983.
Examples:
>>> dyt = get_vinokur_grid_steps(14, 180, 5, two_sided_grid=True)
>>> dyt
[ 18.2451554 17.23915939 15.43744632 13.17358802 10.78720589
8.53852027 6.57892471 6.57892471 8.53852027 10.78720589
13.17358802 15.43744632 17.23915939 18.2451554 ]
>>> dyt.sum()
180.
>>> dyt = get_vinokur_grid_steps(14, 180, 5, upper_stepsize=10)
>>> dyt
[ 5.9818365 7.3645667 8.92544833 10.61326984 12.33841985
13.97292695 15.36197306 16.3485688 16.80714121 16.67536919
15.97141714 14.78881918 13.27136448 11.57887877 ]
>>> dyt.sum()
180.
"""
if refine_towards not in ("upper", "lower"):
raise ValueError('refine_towards must be "upper" or "lower"')
if two_sided_grid:
if n_cells % 2:
raise ValueError(f"number of grid points must be an even integer (given: {n_cells})")
n_cells = n_cells // 2
n_cells += 1
def approximate_sinc_inverse(y):
"""Approximate inverse of sin(y) / y"""
if y < 0.26938972:
inv = npx.pi * (
1
- y
+ y**2
- (1 + npx.pi**2 / 6) * y**3
+ 6.794732 * y**4
- 13.205501 * y**5
+ 11.726095 * y**6
)
else:
ybar = 1.0 - y
inv = npx.sqrt(6 * ybar) * (
1
+ 0.15 * ybar
+ 0.057321429 * ybar**2
+ 0.048774238 * ybar**3
- 0.053337753 * ybar**4
+ 0.075845134 * ybar**5
)
assert abs(1 - npx.sin(inv) / inv / y) < 1e-2, "precision error"
return inv
def approximate_sinhc_inverse(y):
"""Approximate inverse of sinh(y) / y"""
if y < 2.7829681:
ybar = y - 1.0
inv = npx.sqrt(6 * ybar) * (
1
- 0.15 * ybar
+ 0.057321429 * ybar**2
- 0.024907295 * ybar**3
+ 0.0077424461 * ybar**4
- 0.0010794123 * ybar**5
)
else:
v = npx.log(y)
w = 1.0 / y - 0.028527431
inv = (
v
+ (1 + 1.0 / v) * npx.log(2 * v)
- 0.02041793
+ 0.24902722 * w
+ 1.9496443 * w**2
- 2.6294547 * w**3
+ 8.56795911 * w**4
)
assert abs(1 - npx.sinh(inv) / inv / y) < 1e-2, "precision error"
return inv
target_sum = total_length
if two_sided_grid:
target_sum *= 0.5
s0 = float(target_sum) / float(lower_stepsize * n_cells)
if upper_stepsize:
s1 = float(target_sum) / float(upper_stepsize * n_cells)
a, b = npx.sqrt(s1 / s0), npx.sqrt(s1 * s0)
if b > 1:
stretching_factor = approximate_sinhc_inverse(b)
stretched_grid = 0.5 + 0.5 * npx.tanh(stretching_factor * npx.linspace(-0.5, 0.5, n_cells)) / npx.tanh(
0.5 * stretching_factor
)
else:
stretching_factor = approximate_sinc_inverse(b)
stretched_grid = 0.5 + 0.5 * npx.tan(stretching_factor * npx.linspace(-0.5, 0.5, n_cells)) / npx.tan(
0.5 * stretching_factor
)
stretched_grid = stretched_grid / (a + (1.0 - a) * stretched_grid)
else:
if s0 > 1:
stretching_factor = approximate_sinhc_inverse(s0) * 0.5
stretched_grid = 1 + npx.tanh(stretching_factor * npx.linspace(0.0, 1.0, n_cells)) / npx.tanh(
stretching_factor
)
else:
stretching_factor = approximate_sinc_inverse(s0) * 0.5
stretched_grid = 1 + npx.tan(stretching_factor * npx.linspace(0.0, 1.0, n_cells)) / npx.tan(
stretching_factor
)
stretched_grid_steps = npx.diff(stretched_grid * target_sum)
if refine_towards == "upper":
stretched_grid_steps = stretched_grid_steps[::-1]
if two_sided_grid:
stretched_grid_steps = npx.concatenate((stretched_grid_steps[::-1], stretched_grid_steps))
assert abs(1 - npx.sum(stretched_grid_steps) / total_length) < 1e-5, "precision error"
return stretched_grid_steps
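# A minimal usage sketch (illustrative values, not part of the module): build a
# 50-step vertical grid covering 5800 m of depth with ~10 m cells on the
# refined end. The relative error bound mirrors the function's internal check.
#
#     dzt = get_vinokur_grid_steps(50, 5800.0, 10.0, refine_towards="lower")
#     assert abs(1 - dzt.sum() / 5800.0) < 1e-5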
from veros import runtime_settings
class Variable:
def __init__(
self,
name,
dims,
units="",
long_description="",
dtype=None,
time_dependent=True,
scale=1.0,
write_to_restart=False,
extra_attributes=None,
mask=None,
active=True,
initial=None,
):
self.name = name
self.dims = dims
self.units = units
self.long_description = long_description
self.dtype = dtype
self.time_dependent = time_dependent
self.scale = scale
self.write_to_restart = write_to_restart
self.active = active
self.initial = initial
self.get_mask = lambda settings, vs: None
if mask is not None:
if not callable(mask):
raise TypeError("mask argument has to be callable")
self.get_mask = mask
elif isinstance(dims, tuple):
if dims[:3] in DEFAULT_MASKS:
self.get_mask = DEFAULT_MASKS[dims[:3]]
elif dims[:2] in DEFAULT_MASKS:
self.get_mask = DEFAULT_MASKS[dims[:2]]
#: Additional attributes to be written in netCDF output
self.extra_attributes = extra_attributes or {}
def __repr__(self):
attr_str = []
for v in vars(self):
attr_str.append(f"{v}={getattr(self, v)}")
attr_str = ", ".join(attr_str)
return f"{self.__class__.__qualname__}({attr_str})"
# fill value for netCDF output (invalid floating data is replaced by this value)
FLOAT_FILL_VALUE = -1e18
#
XT = ("xt",)
XU = ("xu",)
YT = ("yt",)
YU = ("yu",)
ZT = ("zt",)
ZW = ("zw",)
T_HOR = ("xt", "yt")
U_HOR = ("xu", "yt")
V_HOR = ("xt", "yu")
ZETA_HOR = ("xu", "yu")
T_GRID = ("xt", "yt", "zt")
U_GRID = ("xu", "yt", "zt")
V_GRID = ("xt", "yu", "zt")
W_GRID = ("xt", "yt", "zw")
ZETA_GRID = ("xu", "yu", "zt")
TIMESTEPS = ("timesteps",)
ISLE = ("isle",)
TENSOR_COMP = ("tensor1", "tensor2")
# these dimensions are written to netCDF output by default
BASE_DIMENSIONS = XT + XU + YT + YU + ZT + ZW + ISLE
GHOST_DIMENSIONS = ("xt", "yt", "xu", "yu")
# maps each dimension to the setting (or fixed size) that determines its shape
DIM_TO_SHAPE_VAR = {
"xt": "nx",
"xu": "nx",
"yt": "ny",
"yu": "ny",
"zt": "nz",
"zw": "nz",
"timesteps": 3,
"tensor1": 2,
"tensor2": 2,
"isle": 0,
}
DEFAULT_MASKS = {
T_HOR: lambda settings, vs: vs.maskT[:, :, -1],
U_HOR: lambda settings, vs: vs.maskU[:, :, -1],
V_HOR: lambda settings, vs: vs.maskV[:, :, -1],
ZETA_HOR: lambda settings, vs: vs.maskZ[:, :, -1],
T_GRID: lambda settings, vs: vs.maskT,
U_GRID: lambda settings, vs: vs.maskU,
V_GRID: lambda settings, vs: vs.maskV,
W_GRID: lambda settings, vs: vs.maskW,
ZETA_GRID: lambda settings, vs: vs.maskZ,
}
# custom mask for streamfunction
def _get_psi_mask(settings, vs):
if not settings.enable_streamfunction:
return vs.maskT[:, :, -1]
# eroded around the edges
return vs.maskZ[:, :, -1] | ~vs.isle_boundary_mask
def get_fill_value(dtype):
import numpy as onp
if onp.issubdtype(dtype, onp.floating):
return FLOAT_FILL_VALUE
return onp.iinfo(dtype).max
def get_shape(dimensions, grid, include_ghosts=True, local=True):
from veros.routines import CURRENT_CONTEXT
from veros.distributed import SCATTERED_DIMENSIONS
if grid is None:
return ()
px, py = runtime_settings.num_proc
grid_shapes = dict(dimensions)
if local and CURRENT_CONTEXT.is_dist_safe:
for pxi, dims in zip((px, py), SCATTERED_DIMENSIONS):
for dim in dims:
if dim not in grid_shapes:
continue
grid_shapes[dim] = grid_shapes[dim] // pxi
if include_ghosts:
for d in GHOST_DIMENSIONS:
if d in grid_shapes:
grid_shapes[d] += 4
shape = []
for grid_dim in grid:
if isinstance(grid_dim, int):
shape.append(grid_dim)
continue
if grid_dim not in grid_shapes:
raise ValueError(f"unrecognized dimension {grid_dim}")
shape.append(grid_shapes[grid_dim])
return tuple(shape)
def remove_ghosts(array, dims):
if dims is None:
# scalar
return array
ghost_mask = tuple(slice(2, -2) if dim in GHOST_DIMENSIONS else slice(None) for dim in dims)
return array[ghost_mask]
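# A sketch of the shape logic under assumed sizes (nx=250, ny=350, nz=50, a
# single process): horizontal dimensions gain 2 ghost cells on each side.
#
#     from veros.core.operators import numpy as npx
#     dims = {"xt": 250, "yt": 350, "zt": 50}
#     get_shape(dims, T_GRID, include_ghosts=True, local=False)  # (254, 354, 50)
#     arr = npx.zeros((254, 354, 50))
#     remove_ghosts(arr, T_GRID).shape  # (250, 350, 50)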
VARIABLES = {
# scalars
"tau": Variable(
"Index of current time step",
None,
"",
"Index of current time step",
dtype="int32",
initial=1,
write_to_restart=True,
),
"taup1": Variable(
"Index of next time step", None, "", "Index of next time step", dtype="int32", initial=2, write_to_restart=True
),
"taum1": Variable(
"Index of last time step", None, "", "Index of last time step", dtype="int32", initial=0, write_to_restart=True
),
"time": Variable(
"Current time",
None,
"",
"Current time",
write_to_restart=True,
),
"itt": Variable("Current iteration", None, "", "Current iteration", dtype="int32", initial=0),
# base variables
"dxt": Variable("Zonal T-grid spacing", XT, "m", "Zonal (x) spacing of T-grid point", time_dependent=False),
"dxu": Variable("Zonal U-grid spacing", XU, "m", "Zonal (x) spacing of U-grid point", time_dependent=False),
"dyt": Variable(
"Meridional T-grid spacing", YT, "m", "Meridional (y) spacing of T-grid point", time_dependent=False
),
"dyu": Variable(
"Meridional U-grid spacing", YU, "m", "Meridional (y) spacing of U-grid point", time_dependent=False
),
"zt": Variable(
"Vertical coordinate (T)",
ZT,
"m",
"Vertical coordinate",
time_dependent=False,
extra_attributes={"positive": "up"},
),
"zw": Variable(
"Vertical coordinate (W)",
ZW,
"m",
"Vertical coordinate",
time_dependent=False,
extra_attributes={"positive": "up"},
),
"dzt": Variable("Vertical spacing (T)", ZT, "m", "Vertical spacing", time_dependent=False),
"dzw": Variable("Vertical spacing (W)", ZW, "m", "Vertical spacing", time_dependent=False),
"cost": Variable("Metric factor (T)", YT, "1", "Metric factor for spherical coordinates", time_dependent=False),
"cosu": Variable("Metric factor (U)", YU, "1", "Metric factor for spherical coordinates", time_dependent=False),
"tantr": Variable("Metric factor", YT, "1", "Metric factor for spherical coordinates", time_dependent=False),
"coriolis_t": Variable(
"Coriolis frequency", T_HOR, "1/s", "Coriolis frequency at T grid point", time_dependent=False
),
"kbot": Variable(
"Index of deepest cell",
T_HOR,
"",
"Index of the deepest grid cell (counting from 1, 0 means all land)",
dtype="int32",
time_dependent=False,
),
"ht": Variable("Total depth (T)", T_HOR, "m", "Total depth of the water column", time_dependent=False),
"hu": Variable("Total depth (U)", U_HOR, "m", "Total depth of the water column", time_dependent=False),
"hv": Variable("Total depth (V)", V_HOR, "m", "Total depth of the water column", time_dependent=False),
"hur": Variable(
"Total depth (U), masked", U_HOR, "m", "Total depth of the water column (masked)", time_dependent=False
),
"hvr": Variable(
"Total depth (V), masked", V_HOR, "m", "Total depth of the water column (masked)", time_dependent=False
),
"beta": Variable(
"Change of Coriolis freq.", T_HOR, "1/(ms)", "Change of Coriolis frequency with latitude", time_dependent=False
),
"area_t": Variable("Area of T-box", T_HOR, "m^2", "Area of T-box", time_dependent=False),
"area_u": Variable("Area of U-box", U_HOR, "m^2", "Area of U-box", time_dependent=False),
"area_v": Variable("Area of V-box", V_HOR, "m^2", "Area of V-box", time_dependent=False),
"maskT": Variable(
"Mask for tracer points",
T_GRID,
"",
"Mask in physical space for tracer points",
time_dependent=False,
dtype="bool",
),
"maskU": Variable(
"Mask for U points",
U_GRID,
"",
"Mask in physical space for U points",
time_dependent=False,
dtype="bool",
),
"maskV": Variable(
"Mask for V points",
V_GRID,
"",
"Mask in physical space for V points",
time_dependent=False,
dtype="bool",
),
"maskW": Variable(
"Mask for W points",
W_GRID,
"",
"Mask in physical space for W points",
time_dependent=False,
dtype="bool",
),
"maskZ": Variable(
"Mask for Zeta points",
ZETA_GRID,
"",
"Mask in physical space for Zeta points",
time_dependent=False,
dtype="bool",
),
"rho": Variable(
"Density",
T_GRID + TIMESTEPS,
"kg/m^3",
"In-situ density anomaly, relative to the surface mean value of 1024 kg/m^3",
write_to_restart=True,
),
"prho": Variable(
"Potential density",
T_GRID,
"kg/m^3",
"Potential density anomaly, relative to the surface mean value of 1024 kg/m^3 "
"(identical to in-situ density anomaly for equation of state type 1, 2, and 4)",
),
"int_drhodT": Variable(
"Der. of dyn. enthalpy by temperature",
T_GRID + TIMESTEPS,
"kg / (m^2 deg C)",
"Partial derivative of dynamic enthalpy by temperature",
write_to_restart=True,
),
"int_drhodS": Variable(
"Der. of dyn. enthalpy by salinity",
T_GRID + TIMESTEPS,
"kg / (m^2 g / kg)",
"Partial derivative of dynamic enthalpy by salinity",
write_to_restart=True,
),
"Nsqr": Variable(
"Square of stability frequency",
W_GRID + TIMESTEPS,
"1/s^2",
"Square of stability frequency",
write_to_restart=True,
),
"Hd": Variable("Dynamic enthalpy", T_GRID + TIMESTEPS, "m^2/s^2", "Dynamic enthalpy", write_to_restart=True),
"dHd": Variable(
"Change of dyn. enth. by adv.",
T_GRID + TIMESTEPS,
"m^2/s^3",
"Change of dynamic enthalpy due to advection",
write_to_restart=True,
),
"temp": Variable("Temperature", T_GRID + TIMESTEPS, "deg C", "Conservative temperature", write_to_restart=True),
"dtemp": Variable(
"Temperature tendency",
T_GRID + TIMESTEPS,
"deg C/s",
"Conservative temperature tendency",
write_to_restart=True,
),
"salt": Variable("Salinity", T_GRID + TIMESTEPS, "g/kg", "Salinity", write_to_restart=True),
"dsalt": Variable("Salinity tendency", T_GRID + TIMESTEPS, "g/(kg s)", "Salinity tendency", write_to_restart=True),
"dtemp_vmix": Variable(
"Change of temp. by vertical mixing",
T_GRID,
"deg C/s",
"Change of temperature due to vertical mixing",
),
"dtemp_hmix": Variable(
"Change of temp. by horizontal mixing",
T_GRID,
"deg C/s",
"Change of temperature due to horizontal mixing",
),
"dsalt_vmix": Variable(
"Change of sal. by vertical mixing",
T_GRID,
"deg C/s",
"Change of salinity due to vertical mixing",
),
"dsalt_hmix": Variable(
"Change of sal. by horizontal mixing",
T_GRID,
"deg C/s",
"Change of salinity due to horizontal mixing",
),
"dtemp_iso": Variable(
"Change of temp. by isop. mixing",
T_GRID,
"deg C/s",
"Change of temperature due to isopycnal mixing plus skew mixing",
),
"dsalt_iso": Variable(
"Change of sal. by isop. mixing",
T_GRID,
"deg C/s",
"Change of salinity due to isopycnal mixing plus skew mixing",
),
"forc_temp_surface": Variable(
"Surface temperature flux",
T_HOR,
"m deg C/s",
"Surface temperature flux",
),
"forc_salt_surface": Variable(
"Surface salinity flux",
T_HOR,
"m g/s kg",
"Surface salinity flux",
),
"u": Variable("Zonal velocity", U_GRID + TIMESTEPS, "m/s", "Zonal velocity", write_to_restart=True),
"v": Variable("Meridional velocity", V_GRID + TIMESTEPS, "m/s", "Meridional velocity", write_to_restart=True),
"w": Variable("Vertical velocity", W_GRID + TIMESTEPS, "m/s", "Vertical velocity", write_to_restart=True),
"du": Variable(
"Zonal velocity tendency", U_GRID + TIMESTEPS, "m/s", "Zonal velocity tendency", write_to_restart=True
),
"dv": Variable(
"Meridional velocity tendency", V_GRID + TIMESTEPS, "m/s", "Meridional velocity tendency", write_to_restart=True
),
"du_cor": Variable("Change of u by Coriolis force", U_GRID, "m/s^2", "Change of u due to Coriolis force"),
"dv_cor": Variable("Change of v by Coriolis force", V_GRID, "m/s^2", "Change of v due to Coriolis force"),
"du_mix": Variable(
"Change of u by vertical mixing", U_GRID, "m/s^2", "Change of u due to implicit vertical mixing"
),
"dv_mix": Variable(
"Change of v by vertical mixing", V_GRID, "m/s^2", "Change of v due to implicit vertical mixing"
),
"du_adv": Variable("Change of u by advection", U_GRID, "m/s^2", "Change of u due to advection"),
"dv_adv": Variable("Change of v by advection", V_GRID, "m/s^2", "Change of v due to advection"),
"p_hydro": Variable("Hydrostatic pressure", T_GRID, "m^2/s^2", "Hydrostatic pressure"),
"kappaM": Variable("Vertical viscosity", T_GRID, "m^2/s", "Vertical viscosity"),
"kappaH": Variable("Vertical diffusivity", W_GRID, "m^2/s", "Vertical diffusivity"),
"surface_taux": Variable(
"Surface wind stress",
U_HOR,
"N/m^2",
"Zonal surface wind stress",
),
"surface_tauy": Variable(
"Surface wind stress",
V_HOR,
"N/m^2",
"Meridional surface wind stress",
),
"forc_rho_surface": Variable("Surface density flux", T_HOR, "kg / (m^2 s)", "Surface potential density flux"),
"ssh": Variable(
"Sea surface height",
T_HOR,
"m",
"Sea surface height",
active=lambda settings: not settings.enable_streamfunction,
),
"psi": Variable(
lambda settings: "Streamfunction" if settings.enable_streamfunction else "Surface pressure",
lambda settings: ZETA_HOR + TIMESTEPS if settings.enable_streamfunction else T_HOR + TIMESTEPS,
lambda settings: "m^3/s" if settings.enable_streamfunction else "m^2/s^2",
lambda settings: "Barotropic streamfunction" if settings.enable_streamfunction else "Surface pressure",
write_to_restart=True,
mask=_get_psi_mask,
),
"dpsi": Variable(
"Streamfunction tendency",
ZETA_HOR + TIMESTEPS,
"m^3/s^2",
"Streamfunction tendency",
write_to_restart=True,
active=lambda settings: settings.enable_streamfunction,
),
"land_map": Variable(
"Land map",
T_HOR,
"",
"Land map",
dtype="int32",
active=lambda settings: settings.enable_streamfunction,
),
"isle": Variable(
"Island number",
ISLE,
"",
"Island number",
dtype="int32",
active=lambda settings: settings.enable_streamfunction,
),
"psin": Variable(
"Boundary streamfunction",
ZETA_HOR + ISLE,
"m^3/s",
"Boundary streamfunction",
time_dependent=False,
mask=_get_psi_mask,
active=lambda settings: settings.enable_streamfunction,
),
"dpsin": Variable(
"Boundary streamfunction factor",
ISLE + TIMESTEPS,
"m^3/s^2",
"Boundary streamfunction factor",
write_to_restart=True,
active=lambda settings: settings.enable_streamfunction,
),
"line_psin": Variable(
"Boundary line integrals",
ISLE + ISLE,
"m^4/s^2",
"Boundary line integrals",
time_dependent=False,
active=lambda settings: settings.enable_streamfunction,
),
"isle_boundary_mask": Variable(
"Island boundary mask",
T_HOR,
"",
"Island boundary mask",
time_dependent=False,
dtype="bool",
active=lambda settings: settings.enable_streamfunction,
),
"line_dir_south_mask": Variable(
"Line integral mask",
T_HOR + ISLE,
"",
"Line integral mask",
time_dependent=False,
dtype="bool",
active=lambda settings: settings.enable_streamfunction,
),
"line_dir_north_mask": Variable(
"Line integral mask",
T_HOR + ISLE,
"",
"Line integral mask",
time_dependent=False,
dtype="bool",
active=lambda settings: settings.enable_streamfunction,
),
"line_dir_east_mask": Variable(
"Line integral mask",
T_HOR + ISLE,
"",
"Line integral mask",
time_dependent=False,
dtype="bool",
active=lambda settings: settings.enable_streamfunction,
),
"line_dir_west_mask": Variable(
"Line integral mask",
T_HOR + ISLE,
"",
"Line integral mask",
time_dependent=False,
dtype="bool",
active=lambda settings: settings.enable_streamfunction,
),
"K_gm": Variable("Skewness diffusivity", W_GRID, "m^2/s", "GM diffusivity, either constant or from EKE model"),
"K_iso": Variable("Isopycnal diffusivity", W_GRID, "m^2/s", "Along-isopycnal diffusivity"),
"K_diss_v": Variable(
"Dissipation of kinetic Energy",
W_GRID,
"m^2/s^3",
"Kinetic energy dissipation by vertical, rayleigh and bottom friction",
write_to_restart=True,
),
"K_diss_bot": Variable(
"Dissipation of kinetic Energy", W_GRID, "m^2/s^3", "Mean energy dissipation by bottom and rayleigh friction"
),
"K_diss_h": Variable(
"Dissipation of kinetic Energy", W_GRID, "m^2/s^3", "Kinetic energy dissipation by horizontal friction"
),
"K_diss_gm": Variable(
"Dissipation of mean energy",
W_GRID,
"m^2/s^3",
"Mean energy dissipation by GM (TRM formalism only)",
),
"P_diss_v": Variable(
"Dissipation of potential Energy", W_GRID, "m^2/s^3", "Potential energy dissipation by vertical diffusion"
),
"P_diss_nonlin": Variable(
"Dissipation of potential Energy",
W_GRID,
"m^2/s^3",
"Potential energy dissipation by nonlinear equation of state",
),
"P_diss_iso": Variable(
"Dissipation of potential Energy", W_GRID, "m^2/s^3", "Potential energy dissipation by isopycnal mixing"
),
"P_diss_skew": Variable(
"Dissipation of potential Energy", W_GRID, "m^2/s^3", "Potential energy dissipation by GM (w/o TRM)"
),
"P_diss_hmix": Variable(
"Dissipation of potential Energy", W_GRID, "m^2/s^3", "Potential energy dissipation by horizontal mixing"
),
"P_diss_adv": Variable(
"Dissipation of potential Energy", W_GRID, "m^2/s^3", "Potential energy dissipation by advection"
),
"P_diss_sources": Variable(
"Dissipation of potential Energy",
W_GRID,
"m^2/s^3",
"Potential energy dissipation by external sources (e.g. restoring zones)",
),
"u_wgrid": Variable("U on W grid", W_GRID, "m/s", "Zonal velocity interpolated to W grid points"),
"v_wgrid": Variable("V on W grid", W_GRID, "m/s", "Meridional velocity interpolated to W grid points"),
"w_wgrid": Variable("W on W grid", W_GRID, "m/s", "Vertical velocity interpolated to W grid points"),
"xt": Variable(
"Zonal coordinate (T)",
XT,
lambda settings: "degrees_east" if settings.coord_degree else "km",
"Zonal (x) coordinate of T-grid point",
time_dependent=False,
scale=lambda settings: 1 if settings.coord_degree else 1e-3,
),
"xu": Variable(
"Zonal coordinate (U)",
XU,
lambda settings: "degrees_east" if settings.coord_degree else "km",
"Zonal (x) coordinate of U-grid point",
time_dependent=False,
scale=lambda settings: 1 if settings.coord_degree else 1e-3,
),
"yt": Variable(
"Meridional coordinate (T)",
YT,
lambda settings: "degrees_north" if settings.coord_degree else "km",
"Meridional (y) coordinate of T-grid point",
time_dependent=False,
scale=lambda settings: 1 if settings.coord_degree else 1e-3,
),
"yu": Variable(
"Meridional coordinate (U)",
YU,
lambda settings: "degrees_north" if settings.coord_degree else "km",
"Meridional (y) coordinate of U-grid point",
time_dependent=False,
scale=lambda settings: 1 if settings.coord_degree else 1e-3,
),
"temp_source": Variable(
"Source of temperature",
T_GRID,
"K/s",
"Non-conservative source of temperature",
active=lambda settings: settings.enable_tempsalt_sources,
),
"salt_source": Variable(
"Source of salt",
T_GRID,
"g/(kg s)",
"Non-conservative source of salt",
active=lambda settings: settings.enable_tempsalt_sources,
),
"u_source": Variable(
"Source of zonal velocity",
U_GRID,
"m/s^2",
"Non-conservative source of zonal velocity",
active=lambda settings: settings.enable_momentum_sources,
),
"v_source": Variable(
"Source of meridional velocity",
V_GRID,
"m/s^2",
"Non-conservative source of meridional velocity",
active=lambda settings: settings.enable_momentum_sources,
),
"K_11": Variable(
"Isopycnal mixing coefficient",
T_GRID,
"m^2/s",
"Isopycnal mixing tensor component",
active=lambda settings: settings.enable_neutral_diffusion,
),
"K_22": Variable(
"Isopycnal mixing coefficient",
T_GRID,
"m^2/s",
"Isopycnal mixing tensor component",
active=lambda settings: settings.enable_neutral_diffusion,
),
"K_33": Variable(
"Isopycnal mixing coefficient",
T_GRID,
"m^2/s",
"Isopycnal mixing tensor component",
active=lambda settings: settings.enable_neutral_diffusion,
),
"Ai_ez": Variable(
"Isopycnal diffusion coefficient",
T_GRID + TENSOR_COMP,
"Vertical isopycnal diffusion coefficient on eastern face of T cell",
"1",
active=lambda settings: settings.enable_neutral_diffusion,
),
"Ai_nz": Variable(
"Isopycnal diffusion coefficient",
T_GRID + TENSOR_COMP,
"Vertical isopycnal diffusion coefficient on northern face of T cell",
"1",
active=lambda settings: settings.enable_neutral_diffusion,
),
"Ai_bx": Variable(
"Isopycnal diffusion coefficient",
T_GRID + TENSOR_COMP,
"Zonal isopycnal diffusion coefficient on bottom face of T cell",
"1",
active=lambda settings: settings.enable_neutral_diffusion,
),
"Ai_by": Variable(
"Isopycnal diffusion coefficient",
T_GRID + TENSOR_COMP,
"Meridional isopycnal diffusion coefficient on bottom face of T cell",
"1",
active=lambda settings: settings.enable_neutral_diffusion,
),
"B1_gm": Variable(
"Zonal component of GM streamfunction",
V_GRID,
"m^2/s",
"Zonal component of GM streamfunction",
active=lambda settings: settings.enable_skew_diffusion,
),
"B2_gm": Variable(
"Meridional component of GM streamfunction",
U_GRID,
"m^2/s",
"Meridional component of GM streamfunction",
active=lambda settings: settings.enable_skew_diffusion,
),
"r_bot_var_u": Variable(
"Bottom friction coeff.",
U_HOR,
"1/s",
"Zonal bottom friction coefficient",
active=lambda settings: settings.enable_bottom_friction_var,
),
"r_bot_var_v": Variable(
"Bottom friction coeff.",
V_HOR,
"1/s",
"Meridional bottom friction coefficient",
active=lambda settings: settings.enable_bottom_friction_var,
),
"kappa_gm": Variable(
"Vertical diffusivity",
W_GRID,
"m^2/s",
"Vertical diffusivity",
active=lambda settings: settings.enable_TEM_friction,
),
"tke": Variable(
"Turbulent kinetic energy",
W_GRID + TIMESTEPS,
"m^2/s^2",
"Turbulent kinetic energy",
write_to_restart=True,
active=lambda settings: settings.enable_tke,
),
"sqrttke": Variable(
"Square-root of TKE",
W_GRID,
"m/s",
"Square-root of TKE",
active=lambda settings: settings.enable_tke,
),
"dtke": Variable(
"Turbulent kinetic energy tendency",
W_GRID + TIMESTEPS,
"m^2/s^3",
"Turbulent kinetic energy tendency",
write_to_restart=True,
active=lambda settings: settings.enable_tke,
),
"Prandtlnumber": Variable(
"Prandtl number",
W_GRID,
"",
"Prandtl number",
active=lambda settings: settings.enable_tke,
),
"mxl": Variable(
"Mixing length",
W_GRID,
"m",
"Mixing length",
active=lambda settings: settings.enable_tke,
),
"forc_tke_surface": Variable(
"TKE surface flux",
T_HOR,
"m^3/s^3",
"TKE surface flux",
active=lambda settings: settings.enable_tke,
),
"tke_diss": Variable(
"TKE dissipation",
W_GRID,
"m^2/s^3",
"TKE dissipation",
active=lambda settings: settings.enable_tke,
),
"tke_surf_corr": Variable(
"Correction of TKE surface flux",
T_HOR,
"m^3/s^3",
"Correction of TKE surface flux",
active=lambda settings: settings.enable_tke,
),
"eke": Variable(
"meso-scale energy",
W_GRID + TIMESTEPS,
"m^2/s^2",
"meso-scale energy",
write_to_restart=True,
active=lambda settings: settings.enable_eke,
),
"deke": Variable(
"meso-scale energy tendency",
W_GRID + TIMESTEPS,
"m^2/s^3",
"meso-scale energy tendency",
write_to_restart=True,
active=lambda settings: settings.enable_eke,
),
"sqrteke": Variable(
"square-root of eke",
W_GRID,
"m/s",
"square-root of eke",
active=lambda settings: settings.enable_eke,
),
"L_rossby": Variable(
"Rossby radius",
T_HOR,
"m",
"Rossby radius",
active=lambda settings: settings.enable_eke,
),
"L_rhines": Variable(
"Rhines scale",
W_GRID,
"m",
"Rhines scale",
active=lambda settings: settings.enable_eke,
),
"eke_len": Variable(
"Eddy length scale",
T_GRID,
"m",
"Eddy length scale",
active=lambda settings: settings.enable_eke,
),
"eke_diss_iw": Variable(
"Dissipation of EKE to IW",
W_GRID,
"m^2/s^3",
"Dissipation of EKE to internal waves",
active=lambda settings: settings.enable_eke,
),
"eke_diss_tke": Variable(
"Dissipation of EKE to TKE",
W_GRID,
"m^2/s^3",
"Dissipation of EKE to TKE",
active=lambda settings: settings.enable_eke,
),
"E_iw": Variable(
"Internal wave energy",
W_GRID + TIMESTEPS,
"m^2/s^2",
"Internal wave energy",
write_to_restart=True,
active=lambda settings: settings.enable_idemix,
),
"dE_iw": Variable(
"Internal wave energy tendency",
W_GRID + TIMESTEPS,
"m^2/s^2",
"Internal wave energy tendency",
write_to_restart=True,
active=lambda settings: settings.enable_idemix,
),
"c0": Variable(
"Vertical IW group velocity",
W_GRID,
"m/s",
"Vertical internal wave group velocity",
active=lambda settings: settings.enable_idemix,
),
"v0": Variable(
"Horizontal IW group velocity",
W_GRID,
"m/s",
"Horizontal internal wave group velocity",
active=lambda settings: settings.enable_idemix,
),
"alpha_c": Variable(
"?",
W_GRID,
"?",
"?",
active=lambda settings: settings.enable_idemix,
),
"iw_diss": Variable(
"IW dissipation",
W_GRID,
"m^2/s^3",
"Internal wave dissipation",
active=lambda settings: settings.enable_idemix,
),
"forc_iw_surface": Variable(
"IW surface forcing",
T_HOR,
"m^3/s^3",
"Internal wave surface forcing",
time_dependent=False,
active=lambda settings: settings.enable_idemix,
),
"forc_iw_bottom": Variable(
"IW bottom forcing",
T_HOR,
"m^3/s^3",
"Internal wave bottom forcing",
time_dependent=False,
active=lambda settings: settings.enable_idemix,
),
}
def manifest_metadata(var_meta, settings):
"""Evaluate callable metadata fields given the current settings."""
from copy import copy
out = {}
for var_name, var_val in var_meta.items():
var_val = copy(var_val)
for attr, attr_val in vars(var_val).items():
if callable(attr_val) and attr != "get_mask":
setattr(var_val, attr, attr_val(settings))
out[var_name] = var_val
    return out
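# A sketch of how callable metadata resolves (assumes a populated settings
# object): with coord_degree=True, the units of "xt" evaluate to "degrees_east".
#
#     resolved = manifest_metadata(VARIABLES, settings)
#     resolved["xt"].units  # "degrees_east" (or "km" if coord_degree is False)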
def allocate(dimensions, grid, dtype=None, include_ghosts=True, local=True, fill=0):
from veros.core.operators import numpy as npx
if dtype is None:
dtype = runtime_settings.float_type
shape = get_shape(dimensions, grid, include_ghosts=include_ghosts, local=local)
out = npx.full(shape, fill, dtype=dtype)
if runtime_settings.backend == "numpy":
out.flags.writeable = False
return out
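# A usage sketch (assumed dimension sizes): allocate a zero-filled surface field
# on the T grid, including ghost cells. Under the numpy backend the returned
# array is read-only, as set above.
#
#     sst = allocate({"xt": 250, "yt": 350}, T_HOR, local=False)
#     sst.shape  # (254, 354)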
import abc
# do not import veros.core here!
from veros import settings, time, signals, distributed, progress, runtime_settings as rs, logger
from veros.state import get_default_state
from veros.plugins import load_plugin
from veros.routines import veros_routine, is_veros_routine
from veros.timer import timer_context
class VerosSetup(metaclass=abc.ABCMeta):
"""Main class for Veros, used for building a model and running it.
Note:
This class is meant to be subclassed. Subclasses need to implement the
methods :meth:`set_parameter`, :meth:`set_topography`, :meth:`set_grid`,
:meth:`set_coriolis`, :meth:`set_initial_conditions`, :meth:`set_forcing`,
:meth:`set_diagnostics`, and :meth:`after_timestep`.
Example:
>>> import matplotlib.pyplot as plt
>>> from veros import VerosSetup
>>>
>>> class MyModel(VerosSetup):
>>> ...
>>>
        >>> simulation = MyModel()
        >>> simulation.setup()
        >>> simulation.run()
>>> plt.imshow(simulation.state.variables.psi[..., 0])
>>> plt.show()
"""
__veros_plugins__ = tuple()
def __init__(self, override=None):
self.override_settings = override or {}
# this should be the first time the core routines are imported
import veros.core # noqa: F401
self._plugin_interfaces = tuple(load_plugin(p) for p in self.__veros_plugins__)
self._setup_done = False
self.state = get_default_state(plugin_interfaces=self._plugin_interfaces)
@abc.abstractmethod
def set_parameter(self, state):
"""To be implemented by subclass.
First function to be called during setup.
Use this to modify the model settings.
Example:
>>> def set_parameter(self, state):
>>> settings = state.settings
>>> settings.nx, settings.ny, settings.nz = (360, 120, 50)
>>> settings.coord_degree = True
>>> settings.enable_cyclic = True
"""
pass
@abc.abstractmethod
def set_initial_conditions(self, state):
"""To be implemented by subclass.
May be used to set initial conditions.
Example:
            >>> @veros_routine
            >>> def set_initial_conditions(self, state):
            >>>     vs = state.variables
            >>>     vs.u = update(vs.u, at[:, :, :, vs.tau], npx.random.rand(*vs.u.shape[:-1]))
"""
pass
@abc.abstractmethod
def set_grid(self, state):
"""To be implemented by subclass.
Has to set the grid spacings :attr:`dxt`, :attr:`dyt`, and :attr:`dzt`,
along with the coordinates of the grid origin, :attr:`x_origin` and
:attr:`y_origin`.
Example:
            >>> @veros_routine
>>> def set_grid(self, state):
>>> vs = state.variables
>>> vs.x_origin, vs.y_origin = 0, 0
>>> vs.dxt = [0.1, 0.05, 0.025, 0.025, 0.05, 0.1]
>>> vs.dyt = 1.
>>> vs.dzt = [10, 10, 20, 50, 100, 200]
"""
pass
@abc.abstractmethod
def set_coriolis(self, state):
"""To be implemented by subclass.
Has to set the Coriolis parameter :attr:`coriolis_t` at T grid cells.
Example:
            >>> @veros_routine
            >>> def set_coriolis(self, state):
            >>>     vs = state.variables
            >>>     settings = state.settings
            >>>     vs.coriolis_t = 2 * settings.omega * npx.sin(vs.yt[npx.newaxis, :] / 180. * settings.pi)
"""
pass
@abc.abstractmethod
def set_topography(self, state):
"""To be implemented by subclass.
Must specify the model topography by setting :attr:`kbot`.
Example:
            >>> @veros_routine
>>> def set_topography(self, state):
>>> vs = state.variables
>>> vs.kbot = update(vs.kbot, at[...], 10)
>>> # add a rectangular island somewhere inside the domain
>>> vs.kbot = update(vs.kbot, at[10:20, 10:20], 0)
"""
pass
@abc.abstractmethod
def set_forcing(self, state):
"""To be implemented by subclass.
Called before every time step to update the external forcing, e.g. through
:attr:`forc_temp_surface`, :attr:`forc_salt_surface`, :attr:`surface_taux`,
:attr:`surface_tauy`, :attr:`forc_tke_surface`, :attr:`temp_source`, or
:attr:`salt_source`. Use this method to implement time-dependent forcing.
Example:
            >>> @veros_routine
            >>> def set_forcing(self, state):
            >>>     vs = state.variables
            >>>     current_month = int(vs.time / (31 * 24 * 60 * 60)) % 12
            >>>     vs.surface_taux = vs._windstress_data[:, :, current_month]
"""
pass
@abc.abstractmethod
def set_diagnostics(self, state):
"""To be implemented by subclass.
Called before setting up the :ref:`diagnostics <diagnostics>`. Use this method e.g. to
mark additional :ref:`variables <variables>` for output.
Example:
            >>> @veros_routine
>>> def set_diagnostics(self, state):
>>> state.diagnostics['snapshot'].output_variables += ['drho', 'dsalt', 'dtemp']
"""
pass
@abc.abstractmethod
def after_timestep(self, state):
"""Called at the end of each time step. Can be used to define custom, setup-specific
events.
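        Example (illustrative only):
            >>> @veros_routine
            >>> def after_timestep(self, state):
            >>>     vs = state.variables
            >>>     # hypothetical: zero out the temperature source every 100 iterations
            >>>     if vs.itt % 100 == 0:
            >>>         vs.temp_source = update(vs.temp_source, at[...], 0.0)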
"""
pass
def _ensure_setup_done(self):
if not self._setup_done:
raise RuntimeError("setup() method has to be called before running the model")
def setup(self):
from veros import diagnostics, restart
from veros.core import numerics, external, isoneutral
setup_funcs = (
self.set_parameter,
self.set_grid,
self.set_coriolis,
self.set_topography,
self.set_initial_conditions,
self.set_diagnostics,
self.set_forcing,
self.after_timestep,
)
for f in setup_funcs:
if not is_veros_routine(f):
raise RuntimeError(
f"{f.__name__} method is not a Veros routine. Please make sure to decorate it "
"with @veros_routine and try again."
)
logger.info("Running model setup")
with self.state.timers["setup"]:
with self.state.settings.unlock():
self.set_parameter(self.state)
for setting, value in self.override_settings.items():
setattr(self.state.settings, setting, value)
settings.check_setting_conflicts(self.state.settings)
distributed.validate_decomposition(self.state.dimensions)
self.state.initialize_variables()
self.state.diagnostics.update(diagnostics.create_default_diagnostics(self.state))
for plugin in self._plugin_interfaces:
for diagnostic in plugin.diagnostics:
self.state.diagnostics[diagnostic.name] = diagnostic()
self.set_grid(self.state)
numerics.calc_grid(self.state)
self.set_coriolis(self.state)
numerics.calc_beta(self.state)
self.set_topography(self.state)
numerics.calc_topo(self.state)
self.set_initial_conditions(self.state)
numerics.calc_initial_conditions(self.state)
if self.state.settings.enable_streamfunction:
external.streamfunction_init(self.state)
for plugin in self._plugin_interfaces:
plugin.setup_entrypoint(self.state)
self.set_diagnostics(self.state)
diagnostics.initialize(self.state)
restart.read_restart(self.state)
self.set_forcing(self.state)
isoneutral.check_isoneutral_slope_crit(self.state)
self._setup_done = True
@veros_routine
def step(self, state):
from veros import diagnostics, restart
from veros.core import idemix, eke, tke, momentum, thermodynamics, advection, utilities, isoneutral, numerics
self._ensure_setup_done()
vs = state.variables
settings = state.settings
with state.timers["diagnostics"]:
restart.write_restart(state)
with state.timers["main"]:
with state.timers["forcing"]:
self.set_forcing(state)
if state.settings.enable_idemix:
with state.timers["idemix"]:
idemix.set_idemix_parameter(state)
with state.timers["eke"]:
eke.set_eke_diffusivities(state)
with state.timers["tke"]:
tke.set_tke_diffusivities(state)
with state.timers["momentum"]:
momentum.momentum(state)
with state.timers["thermodynamics"]:
thermodynamics.thermodynamics(state)
if settings.enable_eke or settings.enable_tke or settings.enable_idemix:
with state.timers["advection"]:
advection.calculate_velocity_on_wgrid(state)
with state.timers["eke"]:
if state.settings.enable_eke:
eke.integrate_eke(state)
with state.timers["idemix"]:
if state.settings.enable_idemix:
idemix.integrate_idemix(state)
with state.timers["tke"]:
if state.settings.enable_tke:
tke.integrate_tke(state)
with state.timers["boundary_exchange"]:
vs.u = utilities.enforce_boundaries(vs.u, settings.enable_cyclic_x)
vs.v = utilities.enforce_boundaries(vs.v, settings.enable_cyclic_x)
if settings.enable_tke:
vs.tke = utilities.enforce_boundaries(vs.tke, settings.enable_cyclic_x)
if settings.enable_eke:
vs.eke = utilities.enforce_boundaries(vs.eke, settings.enable_cyclic_x)
if settings.enable_idemix:
vs.E_iw = utilities.enforce_boundaries(vs.E_iw, settings.enable_cyclic_x)
with state.timers["momentum"]:
momentum.vertical_velocity(state)
with state.timers["plugins"]:
for plugin in self._plugin_interfaces:
with state.timers[plugin.name]:
plugin.run_entrypoint(state)
vs.itt = vs.itt + 1
vs.time = vs.time + settings.dt_tracer
self.after_timestep(state)
with state.timers["diagnostics"]:
if not numerics.sanity_check(state):
raise RuntimeError(f"solution diverged at iteration {vs.itt}")
isoneutral.isoneutral_diag_streamfunction(state)
diagnostics.diagnose(state)
diagnostics.output(state)
# NOTE: benchmarks parse this, do not change / remove
logger.debug(" Time step took {:.2f}s", state.timers["main"].last_time)
        # rotate time indices
vs.taum1, vs.tau, vs.taup1 = vs.tau, vs.taup1, vs.taum1
def run(self, show_progress_bar=None):
"""Main routine of the simulation.
Note:
Make sure to call :meth:`setup` prior to this function.
Arguments:
            show_progress_bar (:obj:`bool`, optional): Whether to show a fancy progress bar via tqdm.
By default, only show if stdout is a terminal and Veros is running on a single process.
"""
from veros import restart
self._ensure_setup_done()
vs = self.state.variables
settings = self.state.settings
time_length, time_unit = time.format_time(settings.runlen)
logger.info(f"\nStarting integration for {time_length:.1f} {time_unit}")
start_time = vs.time
# disable timers for first iteration
timer_context.active = False
pbar = progress.get_progress_bar(self.state, use_tqdm=show_progress_bar)
try:
with signals.signals_to_exception(), pbar:
while vs.time - start_time < settings.runlen:
self.step(self.state)
if not timer_context.active:
timer_context.active = True
pbar.advance_time(settings.dt_tracer)
except: # noqa: E722
logger.critical(f"Stopping integration at iteration {vs.itt}")
raise
else:
logger.success("Integration done\n")
finally:
restart.write_restart(self.state, force=True)
self._timing_summary()
def _timing_summary(self):
timing_summary = []
timing_summary.extend(
[
"",
"Timing summary:",
"(excluding first iteration)",
"---",
" setup time = {:.2f}s".format(self.state.timers["setup"].total_time),
" main loop time = {:.2f}s".format(self.state.timers["main"].total_time),
" forcing = {:.2f}s".format(self.state.timers["forcing"].total_time),
" momentum = {:.2f}s".format(self.state.timers["momentum"].total_time),
" pressure = {:.2f}s".format(self.state.timers["pressure"].total_time),
" friction = {:.2f}s".format(self.state.timers["friction"].total_time),
" thermodynamics = {:.2f}s".format(self.state.timers["thermodynamics"].total_time),
]
)
if rs.profile_mode:
timing_summary.extend(
[
" lateral mixing = {:.2f}s".format(self.state.timers["isoneutral"].total_time),
" vertical mixing = {:.2f}s".format(self.state.timers["vmix"].total_time),
" equation of state = {:.2f}s".format(self.state.timers["eq_of_state"].total_time),
]
)
timing_summary.extend(
[
" advection = {:.2f}s".format(self.state.timers["advection"].total_time),
" EKE = {:.2f}s".format(self.state.timers["eke"].total_time),
" IDEMIX = {:.2f}s".format(self.state.timers["idemix"].total_time),
" TKE = {:.2f}s".format(self.state.timers["tke"].total_time),
" boundary exchange = {:.2f}s".format(self.state.timers["boundary_exchange"].total_time),
" diagnostics and I/O = {:.2f}s".format(self.state.timers["diagnostics"].total_time),
" plugins = {:.2f}s".format(self.state.timers["plugins"].total_time),
]
)
timing_summary.extend(
[
" {:<22} = {:.2f}s".format(plugin.name, self.state.timers[plugin.name].total_time)
for plugin in self._plugin_interfaces
]
)
logger.debug("\n".join(timing_summary))
if rs.profile_mode:
print_profile_summary(self.state.profile_timers, self.state.timers["main"].total_time)
def print_profile_summary(profile_timers, main_loop_time):
profile_timings = ["", "Profile timings:", "[total time spent (% of main loop)]", "---"]
maxwidth = max(len(k) for k in profile_timers.keys())
profile_format_string = "{{:<{}}} = {{:.2f}}s ({{:.2f}}%)".format(maxwidth)
main_loop_time = max(main_loop_time, 1e-8) # prevent division by 0
for name, timer in profile_timers.items():
this_time = timer.total_time
if this_time == 0:
continue
profile_timings.append(profile_format_string.format(name, this_time, 100 * this_time / main_loop_time))
logger.diagnostic("\n".join(profile_timings))
# Version: 0.28
"""The Versioneer - like a rocketeer, but for versions.
The Versioneer
==============
* like a rocketeer, but for versions!
* https://github.com/python-versioneer/python-versioneer
* Brian Warner
* License: Public Domain (Unlicense)
* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3
* [![Latest Version][pypi-image]][pypi-url]
* [![Build Status][travis-image]][travis-url]
This is a tool for managing a recorded version number in setuptools-based
python projects. The goal is to remove the tedious and error-prone "update
the embedded version string" step from your release process. Making a new
release should be as easy as recording a new tag in your version-control
system, and maybe making new tarballs.
## Quick Install
Versioneer provides two installation modes. The "classic" vendored mode installs
a copy of versioneer into your repository. The experimental build-time dependency mode
is intended to allow you to skip this step and simplify the process of upgrading.
### Vendored mode
* `pip install versioneer` to somewhere in your $PATH
* A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is
available, so you can also use `conda install -c conda-forge versioneer`
* add a `[tool.versioneer]` section to your `pyproject.toml` or a
`[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md))
* Note that you will need to add `tomli; python_version < "3.11"` to your
build-time dependencies if you use `pyproject.toml`
* run `versioneer install --vendor` in your source tree, commit the results
* verify version information with `python setup.py version`
### Build-time dependency mode
* `pip install versioneer` to somewhere in your $PATH
* A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is
available, so you can also use `conda install -c conda-forge versioneer`
* add a `[tool.versioneer]` section to your `pyproject.toml` or a
`[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md))
* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`)
to the `requires` key of the `build-system` table in `pyproject.toml`:
```toml
[build-system]
requires = ["setuptools", "versioneer[toml]"]
build-backend = "setuptools.build_meta"
```
* run `versioneer install --no-vendor` in your source tree, commit the results
* verify version information with `python setup.py version`
## Version Identifiers
Source trees come from a variety of places:
* a version-control system checkout (mostly used by developers)
* a nightly tarball, produced by build automation
* a snapshot tarball, produced by a web-based VCS browser, like github's
"tarball from tag" feature
* a release tarball, produced by "setup.py sdist", distributed through PyPI
Within each source tree, the version identifier (either a string or a number,
this tool is format-agnostic) can come from a variety of places:
* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows
about recent "tags" and an absolute revision-id
* the name of the directory into which the tarball was unpacked
* an expanded VCS keyword ($Id$, etc)
* a `_version.py` created by some earlier build step
For released software, the version identifier is closely related to a VCS
tag. Some projects use tag names that include more than just the version
string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool
needs to strip the tag prefix to extract the version identifier. For
unreleased software (between tags), the version identifier should provide
enough information to help developers recreate the same tree, while also
giving them an idea of roughly how old the tree is (after version 1.2, before
version 1.3). Many VCS systems can report a description that captures this,
for example `git describe --tags --dirty --always` reports things like
"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the
0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has
uncommitted changes).
The version identifier is used for multiple purposes:
* to allow the module to self-identify its version: `myproject.__version__`
* to choose a name and prefix for a 'setup.py sdist' tarball
## Theory of Operation
Versioneer works by adding a special `_version.py` file into your source
tree, where your `__init__.py` can import it. This `_version.py` knows how to
dynamically ask the VCS tool for version information at import time.
`_version.py` also contains `$Revision$` markers, and the installation
process marks `_version.py` to have this marker rewritten with a tag name
during the `git archive` command. As a result, generated tarballs will
contain enough information to get the proper version.
To allow `setup.py` to compute a version too, a `versioneer.py` is added to
the top level of your source tree, next to `setup.py` and the `setup.cfg`
that configures it. This overrides several distutils/setuptools commands to
compute the version when invoked, and changes `setup.py build` and `setup.py
sdist` to replace `_version.py` with a small static file that contains just
the generated version data.
## Installation
See [INSTALL.md](./INSTALL.md) for detailed installation instructions.
## Version-String Flavors
Code which uses Versioneer can learn about its version string at runtime by
importing `_version` from your main `__init__.py` file and running the
`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can
import the top-level `versioneer.py` and run `get_versions()`.
Both functions return a dictionary with different flavors of version
information:
* `['version']`: A condensed version string, rendered using the selected
style. This is the most commonly used value for the project's version
string. The default "pep440" style yields strings like `0.11`,
`0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section
below for alternative styles.
* `['full-revisionid']`: detailed revision identifier. For Git, this is the
full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac".
* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the
commit date in ISO 8601 format. This will be None if the date is not
available.
* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that
this is only accurate if run in a VCS checkout, otherwise it is likely to
be False or None
* `['error']`: if the version string could not be computed, this will be set
to a string describing the problem, otherwise it will be None. It may be
useful to throw an exception in setup.py if this is set, to avoid e.g.
creating tarballs with a version string of "unknown".
Some variants are more useful than others. Including `full-revisionid` in a
bug report should allow developers to reconstruct the exact code being tested
(or indicate the presence of local changes that should be shared with the
developers). `version` is suitable for display in an "about" box or a CLI
`--version` output: it can be easily compared against release notes and lists
of bugs fixed in various releases.
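For example (illustrative only, run from a Versioneer-enabled source tree):

```python
import versioneer

info = versioneer.get_versions()
print(info["version"])           # e.g. "0.11+2.g1076c97.dirty"
print(info["full-revisionid"])   # full SHA1 commit id, or None
print(info["dirty"], info["error"], info["date"])
```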
The installer adds the following text to your `__init__.py` to place a basic
version in `YOURPROJECT.__version__`:
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
## Styles
The setup.cfg `style=` configuration controls how the VCS information is
rendered into a version string.
The default style, "pep440", produces a PEP440-compliant string, equal to the
un-prefixed tag name for actual releases, and containing an additional "local
version" section with more detail for in-between builds. For Git, this is
TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags
--dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the
tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and
that this commit is two revisions ("+2") beyond the "0.11" tag. For released
software (exactly equal to a known tag), the identifier will only contain the
stripped tag, e.g. "0.11".
Other styles are available. See [details.md](details.md) in the Versioneer
source tree for descriptions.
## Debugging
Versioneer tries to avoid fatal errors: if something goes wrong, it will tend
to return a version of "0+unknown". To investigate the problem, run `setup.py
version`, which will run the version-lookup code in a verbose mode, and will
display the full contents of `get_versions()` (including the `error` string,
which may help identify what went wrong).
## Known Limitations
Some situations are known to cause problems for Versioneer. This details the
most significant ones. More can be found on Github
[issues page](https://github.com/python-versioneer/python-versioneer/issues).
### Subprojects
Versioneer has limited support for source trees in which `setup.py` is not in
the root directory (e.g. `setup.py` and `.git/` are *not* siblings). There are
two common reasons why `setup.py` might not be in the root:
* Source trees which contain multiple subprojects, such as
[Buildbot](https://github.com/buildbot/buildbot), which contains both
"master" and "slave" subprojects, each with their own `setup.py`,
`setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI
distributions (and upload multiple independently-installable tarballs).
* Source trees whose main purpose is to contain a C library, but which also
provide bindings to Python (and perhaps other languages) in subdirectories.
Versioneer will look for `.git` in parent directories, and most operations
should get the right version string. However `pip` and `setuptools` have bugs
and implementation details which frequently cause `pip install .` from a
subproject directory to fail to find a correct version string (so it usually
defaults to `0+unknown`).
`pip install --editable .` should work correctly. `setup.py install` might
work too.
Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in
some later version.
[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking
this issue. The discussion in
[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the
issue from the Versioneer side in more detail.
[pip PR#3176](https://github.com/pypa/pip/pull/3176) and
[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve
pip to let Versioneer work correctly.
Versioneer-0.16 and earlier only looked for a `.git` directory next to the
`setup.cfg`, so subprojects were completely unsupported with those releases.
### Editable installs with setuptools <= 18.5
`setup.py develop` and `pip install --editable .` allow you to install a
project into a virtualenv once, then continue editing the source code (and
test) without re-installing after every change.
"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a
convenient way to specify executable scripts that should be installed along
with the python package.
These both work as expected when using modern setuptools. When using
setuptools-18.5 or earlier, however, certain operations will cause
`pkg_resources.DistributionNotFound` errors when running the entrypoint
script, which must be resolved by re-installing the package. This happens
when the install happens with one version, then the egg_info data is
regenerated while a different version is checked out. Many setup.py commands
cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into
a different virtualenv), so this can be surprising.
[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes
this one, but upgrading to a newer version of setuptools should probably
resolve it.
## Updating Versioneer
To upgrade your project to a new release of Versioneer, do the following:
* install the new Versioneer (`pip install -U versioneer` or equivalent)
* edit `setup.cfg` and `pyproject.toml`, if necessary,
to include any new configuration settings indicated by the release notes.
See [UPGRADING](./UPGRADING.md) for details.
* re-run `versioneer install --[no-]vendor` in your source tree, to replace
`SRC/_version.py`
* commit any changed files
## Future Directions
This tool is designed to be easily extended to other version-control
systems: all VCS-specific components are in separate directories like
src/git/ . The top-level `versioneer.py` script is assembled from these
components by running make-versioneer.py . In the future, make-versioneer.py
will take a VCS name as an argument, and will construct a version of
`versioneer.py` that is specific to the given VCS. It might also take the
configuration arguments that are currently provided manually during
installation by editing setup.py . Alternatively, it might go the other
direction and include code from all supported VCS systems, reducing the
number of intermediate scripts.
## Similar projects
* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time
dependency
* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of
versioneer
* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools
plugin
## License
To make Versioneer easier to embed, all its code is dedicated to the public
domain. The `_version.py` that it creates is also in the public domain.
Specifically, both are released under the "Unlicense", as described in
https://unlicense.org/.
[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg
[pypi-url]: https://pypi.python.org/pypi/versioneer/
[travis-image]:
https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg
[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer
"""
# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with
# pylint:disable=attribute-defined-outside-init,too-many-arguments
import configparser
import errno
import json
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Callable, Dict
import functools
have_tomllib = True
if sys.version_info >= (3, 11):
import tomllib
else:
try:
import tomli as tomllib
except ImportError:
have_tomllib = False
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
def get_root():
"""Get the project root directory.
We require that all commands are run from the project root, i.e. the
directory that contains setup.py, setup.cfg, and versioneer.py .
"""
root = os.path.realpath(os.path.abspath(os.getcwd()))
setup_py = os.path.join(root, "setup.py")
versioneer_py = os.path.join(root, "versioneer.py")
if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):
# allow 'python path/to/setup.py COMMAND'
root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))
setup_py = os.path.join(root, "setup.py")
versioneer_py = os.path.join(root, "versioneer.py")
if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):
err = ("Versioneer was unable to run the project root directory. "
"Versioneer requires setup.py to be executed from "
"its immediate directory (like 'python setup.py COMMAND'), "
"or in a way that lets it use sys.argv[0] to find the root "
"(like 'python path/to/setup.py COMMAND').")
raise VersioneerBadRootError(err)
try:
# Certain runtime workflows (setup.py install/develop in a setuptools
# tree) execute all dependencies in a single python process, so
# "versioneer" may be imported multiple times, and python's shared
# module-import table will cache the first one. So we can't use
# os.path.dirname(__file__), as that will find whichever
# versioneer.py was first imported, even in later projects.
my_path = os.path.realpath(os.path.abspath(__file__))
me_dir = os.path.normcase(os.path.splitext(my_path)[0])
vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])
if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals():
print("Warning: build in %s is using versioneer.py from %s"
% (os.path.dirname(my_path), versioneer_py))
except NameError:
pass
return root
def get_config_from_root(root):
"""Read the project setup.cfg file to determine Versioneer config."""
# This might raise OSError (if setup.cfg is missing), or
# configparser.NoSectionError (if it lacks a [versioneer] section), or
# configparser.NoOptionError (if it lacks "VCS="). See the docstring at
# the top of versioneer.py for instructions on writing your setup.cfg .
root = Path(root)
pyproject_toml = root / "pyproject.toml"
setup_cfg = root / "setup.cfg"
section = None
if pyproject_toml.exists() and have_tomllib:
try:
with open(pyproject_toml, 'rb') as fobj:
pp = tomllib.load(fobj)
section = pp['tool']['versioneer']
except (tomllib.TOMLDecodeError, KeyError):
pass
if not section:
parser = configparser.ConfigParser()
with open(setup_cfg) as cfg_file:
parser.read_file(cfg_file)
parser.get("versioneer", "VCS") # raise error if missing
section = parser["versioneer"]
cfg = VersioneerConfig()
cfg.VCS = section['VCS']
cfg.style = section.get("style", "")
cfg.versionfile_source = section.get("versionfile_source")
cfg.versionfile_build = section.get("versionfile_build")
cfg.tag_prefix = section.get("tag_prefix")
if cfg.tag_prefix in ("''", '""', None):
cfg.tag_prefix = ""
cfg.parentdir_prefix = section.get("parentdir_prefix")
cfg.verbose = section.get("verbose")
return cfg
class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
# these dictionaries contain VCS-specific tools
LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
HANDLERS.setdefault(vcs, {})[method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
popen_kwargs = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
process = subprocess.Popen([command] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None), **popen_kwargs)
break
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
if verbose:
print("unable to run %s" % dispcmd)
print(e)
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = process.communicate()[0].strip().decode()
if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
return None, process.returncode
return stdout, process.returncode
LONG_VERSION_PY['git'] = r'''
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain.
# Generated by versioneer-0.28
# https://github.com/python-versioneer/python-versioneer
"""Git implementation of _version.py."""
import errno
import os
import re
import subprocess
import sys
from typing import Callable, Dict
import functools
def get_keywords():
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
# each be defined on a line of their own. _version.py will just call
# get_keywords().
git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s"
git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s"
git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s"
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
return keywords
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
def get_config():
"""Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates
# _version.py
cfg = VersioneerConfig()
cfg.VCS = "git"
cfg.style = "%(STYLE)s"
cfg.tag_prefix = "%(TAG_PREFIX)s"
cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s"
cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s"
cfg.verbose = False
return cfg
class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
popen_kwargs = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
process = subprocess.Popen([command] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None), **popen_kwargs)
break
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
if verbose:
print("unable to run %%s" %% dispcmd)
print(e)
return None, None
else:
if verbose:
print("unable to find command, tried %%s" %% (commands,))
return None, None
stdout = process.communicate()[0].strip().decode()
if process.returncode != 0:
if verbose:
print("unable to run %%s (error)" %% dispcmd)
print("stdout was %%s" %% stdout)
return None, process.returncode
return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
the project name and a version string. We will also support searching up
two directory levels for an appropriately named parent directory.
"""
rootdirs = []
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %%s but none started with prefix %%s" %%
(str(rootdirs), parentdir_prefix))
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords = {}
try:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
except OSError:
pass
return keywords
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
# git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
# it's been around since git-1.5.3, and it's too difficult to
# discover which version we're using, or to work around using an
# older one.
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %%d
# expansion behaves like git log --decorate=short and strips out the
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%%s', no digits" %% ",".join(refs - tags))
if verbose:
print("likely tags: %%s" %% ",".join(sorted(tags)))
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r'\d', r):
continue
if verbose:
print("picking %%s" %% r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
expanded, and _version.py hasn't already been rewritten with a short
version string, meaning we're inside a checked out source tree.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
# GIT_DIR can interfere with correct operation of Versioneer.
# It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=not verbose)
if rc != 0:
if verbose:
print("Directory %%s not under git control" %% root)
raise NotThisMethod("'git rev-parse --git-dir' returned error")
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = runner(GITS, [
"describe", "--tags", "--dirty", "--always", "--long",
"--match", f"{tag_prefix}[[:digit:]]*"
], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
# look for -dirty suffix
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%%s'"
%% describe_out)
return pieces
# tag
full_tag = mo.group(1)
if not full_tag.startswith(tag_prefix):
if verbose:
fmt = "tag '%%s' doesn't start with prefix '%%s'"
print(fmt %% (full_tag, tag_prefix))
pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'"
%% (full_tag, tag_prefix))
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
# commit: short hex revision ID
pieces["short"] = mo.group(3)
else:
# HEX: no tags
pieces["closest-tag"] = None
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
def plus_or_dot(pieces):
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces):
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
Exceptions:
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += plus_or_dot(pieces)
rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_branch(pieces):
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions:
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%%d.g%%s" %% (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver):
"""Split pep440 version string at the post-release segment.
Returns the release segments before the post-release and the
post-release version number (or None if no post-release segment is present).
"""
vc = str.split(ver, ".post")
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
def render_pep440_pre(pieces):
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]:
# update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%%d" %% (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%%d" %% pieces["distance"]
return rendered
def render_pep440_post(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
(a dirty tree will appear "older" than the corresponding clean one),
but you shouldn't be releasing software with -dirty anyway.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%%d" %% pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%%s" %% pieces["short"]
else:
# exception #1
rendered = "0.post%%d" %% pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += "+g%%s" %% pieces["short"]
return rendered
def render_pep440_post_branch(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%%d" %% pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%%s" %% pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%%d" %% pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%%s" %% pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%%d" %% pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
else:
# exception #1
rendered = "0.post%%d" %% pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
return rendered
def render_git_describe(pieces):
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render_git_describe_long(pieces):
"""TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always --long'.
The distance/hash is unconditional.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
if not style or style == "default":
style = "pep440" # the default
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
rendered = render_git_describe(pieces)
elif style == "git-describe-long":
rendered = render_git_describe_long(pieces)
else:
raise ValueError("unknown style '%%s'" %% style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
def get_versions():
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
# case we can only use expanded keywords.
cfg = get_config()
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
verbose)
except NotThisMethod:
pass
try:
root = os.path.realpath(__file__)
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
return render(pieces, cfg.style)
except NotThisMethod:
pass
try:
if cfg.parentdir_prefix:
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
except NotThisMethod:
pass
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None,
"error": "unable to compute version", "date": None}
'''
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords = {}
try:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
except OSError:
pass
return keywords
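# Editor's note: a sketch of the _version.py lines this regexp scan picks up
# once git-archive has expanded the $Format$ keywords (values hypothetical):
#
#   git_refnames = " (HEAD -> master, tag: v1.2.3)"
#   git_full = "1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b"
#   git_date = "2023-01-01 12:00:00 +0000"
#
# which yields keywords = {"refnames": ..., "full": ..., "date": ...}.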
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
# it's been around since git-1.5.3, and it's too difficult to
# discover which version we're using, or to work around using an
# older one.
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
# expansion behaves like git log --decorate=short and strips out the
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
print("likely tags: %s" % ",".join(sorted(tags)))
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r'\d', r):
continue
if verbose:
print("picking %s" % r)
return {"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": None,
"date": date}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False, "error": "no suitable tags", "date": None}
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
expanded, and _version.py hasn't already been rewritten with a short
version string, meaning we're inside a checked out source tree.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
# GIT_DIR can interfere with correct operation of Versioneer.
# It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=not verbose)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
raise NotThisMethod("'git rev-parse --git-dir' returned error")
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = runner(GITS, [
"describe", "--tags", "--dirty", "--always", "--long",
"--match", f"{tag_prefix}[[:digit:]]*"
], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
# look for -dirty suffix
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[:git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
return pieces
# tag
full_tag = mo.group(1)
if not full_tag.startswith(tag_prefix):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
% (full_tag, tag_prefix))
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix):]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
# commit: short hex revision ID
pieces["short"] = mo.group(3)
else:
# HEX: no tags
pieces["closest-tag"] = None
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
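# Editor's note: a sketch of the pieces dict this function would return for a
# hypothetical checkout where 'git describe' printed
# "v1.2.3-4-g1a2b3c4-dirty" with tag_prefix "v":
#
#   {"long": "<full 40-char hex>", "short": "1a2b3c4", "error": None,
#    "branch": "master", "dirty": True, "closest-tag": "1.2.3",
#    "distance": 4, "date": "2023-01-01T12:00:00+0000"}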
def do_vcs_install(versionfile_source, ipy):
"""Git-specific installation logic for Versioneer.
For Git, this means creating/changing .gitattributes to mark _version.py
for export-subst keyword substitution.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
files = [versionfile_source]
if ipy:
files.append(ipy)
if "VERSIONEER_PEP518" not in globals():
try:
my_path = __file__
if my_path.endswith((".pyc", ".pyo")):
my_path = os.path.splitext(my_path)[0] + ".py"
versioneer_file = os.path.relpath(my_path)
except NameError:
versioneer_file = "versioneer.py"
files.append(versioneer_file)
present = False
try:
with open(".gitattributes", "r") as fobj:
for line in fobj:
if line.strip().startswith(versionfile_source):
if "export-subst" in line.strip().split()[1:]:
present = True
break
except OSError:
pass
if not present:
with open(".gitattributes", "a+") as fobj:
fobj.write(f"{versionfile_source} export-subst\n")
files.append(".gitattributes")
run_command(GITS, ["add", "--"] + files)
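# Editor's note: after installation, .gitattributes contains a line like the
# following (path hypothetical), which tells 'git archive' to expand the
# $Format$ keywords inside _version.py:
#
#   src/myproject/_version.py export-subst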
def versions_from_parentdir(parentdir_prefix, root, verbose):
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
the project name and a version string. We will also support searching up
two directory levels for an appropriately named parent directory.
"""
rootdirs = []
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
(str(rootdirs), parentdir_prefix))
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
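# Editor's note: an illustrative match, assuming parentdir_prefix "myproject-"
# and an unpacked tarball living at the hypothetical path /tmp/myproject-1.2.3:
#
#   versions_from_parentdir("myproject-", "/tmp/myproject-1.2.3", False)
#   # -> {"version": "1.2.3", "full-revisionid": None, "dirty": False,
#   #     "error": None, "date": None}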
SHORT_VERSION_PY = """
# This file was generated by 'versioneer.py' (0.28) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.
import json
version_json = '''
%s
''' # END VERSION_JSON
def get_versions():
return json.loads(version_json)
"""
def versions_from_file(filename):
"""Try to determine the version from _version.py if present."""
try:
with open(filename) as f:
contents = f.read()
except OSError:
raise NotThisMethod("unable to read _version.py")
mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON",
contents, re.M | re.S)
if not mo:
mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON",
contents, re.M | re.S)
if not mo:
raise NotThisMethod("no version_json in _version.py")
return json.loads(mo.group(1))
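# Editor's note: a sketch of the JSON payload between the VERSION_JSON
# markers (field values hypothetical; the keys match what
# write_to_version_file() dumps below):
#
#   {
#    "date": "2023-01-01T12:00:00+0000",
#    "dirty": false,
#    "error": null,
#    "full-revisionid": "1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b",
#    "version": "1.2.3"
#   }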
def write_to_version_file(filename, versions):
"""Write the given version number to the given _version.py file."""
os.unlink(filename)
contents = json.dumps(versions, sort_keys=True,
indent=1, separators=(",", ": "))
with open(filename, "w") as f:
f.write(SHORT_VERSION_PY % contents)
print("set %s to '%s'" % (filename, versions["version"]))
def plus_or_dot(pieces):
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces):
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
Exceptions:
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
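def _render_pep440_example():
    """Editor's sketch, not part of the Versioneer API: demonstrate
    render_pep440 on hypothetical pieces describing a dirty checkout four
    commits past tag 1.2.3."""
    pieces = {"closest-tag": "1.2.3", "distance": 4, "short": "1a2b3c4",
              "dirty": True}
    return render_pep440(pieces)  # -> "1.2.3+4.g1a2b3c4.dirty"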
def render_pep440_branch(pieces):
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions:
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver):
"""Split pep440 version string at the post-release segment.
Returns the release segments before the post-release and the
post-release version number (or None if no post-release segment is present).
"""
vc = str.split(ver, ".post")
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
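# Editor's note: illustrative results (inputs hypothetical):
#
#   pep440_split_post("1.2.3.post2")  # -> ("1.2.3", 2)
#   pep440_split_post("1.2.3.post")   # -> ("1.2.3", 0)
#   pep440_split_post("1.2.3")        # -> ("1.2.3", None)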
def render_pep440_pre(pieces):
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]:
# update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
return rendered
def render_pep440_post(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
(a dirty tree will appear "older" than the corresponding clean one),
but you shouldn't be releasing software with -dirty anyway.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
return rendered
def render_pep440_post_branch(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
return rendered
def render_git_describe(pieces):
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render_git_describe_long(pieces):
"""TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always --long'.
The distance/hash is unconditional.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None}
if not style or style == "default":
style = "pep440" # the default
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
rendered = render_git_describe(pieces)
elif style == "git-describe-long":
rendered = render_git_describe_long(pieces)
else:
raise ValueError("unknown style '%s'" % style)
return {"version": rendered, "full-revisionid": pieces["long"],
"dirty": pieces["dirty"], "error": None,
"date": pieces.get("date")}
class VersioneerBadRootError(Exception):
"""The project root directory is unknown or missing key files."""
def get_versions(verbose=False):
"""Get the project version from whatever source is available.
Returns dict with two keys: 'version' and 'full'.
"""
if "versioneer" in sys.modules:
# see the discussion in cmdclass.py:get_cmdclass()
del sys.modules["versioneer"]
root = get_root()
cfg = get_config_from_root(root)
assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg"
handlers = HANDLERS.get(cfg.VCS)
assert handlers, "unrecognized VCS '%s'" % cfg.VCS
verbose = verbose or cfg.verbose
assert cfg.versionfile_source is not None, \
"please set versioneer.versionfile_source"
assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix"
versionfile_abs = os.path.join(root, cfg.versionfile_source)
# extract version from first of: _version.py, VCS command (e.g. 'git
# describe'), parentdir. This is meant to work for developers using a
# source checkout, for users of a tarball created by 'setup.py sdist',
# and for users of a tarball/zipball created by 'git archive' or github's
# download-from-tag feature or the equivalent in other VCSes.
get_keywords_f = handlers.get("get_keywords")
from_keywords_f = handlers.get("keywords")
if get_keywords_f and from_keywords_f:
try:
keywords = get_keywords_f(versionfile_abs)
ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)
if verbose:
print("got version from expanded keyword %s" % ver)
return ver
except NotThisMethod:
pass
try:
ver = versions_from_file(versionfile_abs)
if verbose:
print("got version from file %s %s" % (versionfile_abs, ver))
return ver
except NotThisMethod:
pass
from_vcs_f = handlers.get("pieces_from_vcs")
if from_vcs_f:
try:
pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
ver = render(pieces, cfg.style)
if verbose:
print("got version from VCS %s" % ver)
return ver
except NotThisMethod:
pass
try:
if cfg.parentdir_prefix:
ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
if verbose:
print("got version from parentdir %s" % ver)
return ver
except NotThisMethod:
pass
if verbose:
print("unable to compute version")
return {"version": "0+unknown", "full-revisionid": None,
"dirty": None, "error": "unable to compute version",
"date": None}
def get_version():
"""Get the short version string for this project."""
return get_versions()["version"]
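# Editor's note: the usual setup.py wiring for these two entry points (see
# also CONFIG_ERROR below):
#
#   import versioneer
#   setup(version=versioneer.get_version(),
#         cmdclass=versioneer.get_cmdclass(), ...)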
def get_cmdclass(cmdclass=None):
"""Get the custom setuptools subclasses used by Versioneer.
If the package uses a different cmdclass (e.g. one from numpy), it
should be provided as an argument.
"""
if "versioneer" in sys.modules:
del sys.modules["versioneer"]
# this fixes the "python setup.py develop" case (also 'install' and
# 'easy_install .'), in which subdependencies of the main project are
# built (using setup.py bdist_egg) in the same python process. Assume
# a main project A and a dependency B, which use different versions
# of Versioneer. A's setup.py imports A's Versioneer, leaving it in
# sys.modules by the time B's setup.py is executed, causing B to run
# with the wrong versioneer. Setuptools wraps the sub-dep builds in a
# sandbox that restores sys.modules to its pre-build state, so the
# parent is protected against the child's "import versioneer". By
# removing ourselves from sys.modules here, before the child build
# happens, we protect the child from the parent's versioneer too.
# Also see https://github.com/python-versioneer/python-versioneer/issues/52
cmds = {} if cmdclass is None else cmdclass.copy()
# we add "version" to setuptools
from setuptools import Command
class cmd_version(Command):
description = "report generated version string"
user_options = []
boolean_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
vers = get_versions(verbose=True)
print("Version: %s" % vers["version"])
print(" full-revisionid: %s" % vers.get("full-revisionid"))
print(" dirty: %s" % vers.get("dirty"))
print(" date: %s" % vers.get("date"))
if vers["error"]:
print(" error: %s" % vers["error"])
cmds["version"] = cmd_version
# we override "build_py" in setuptools
#
# most invocation pathways end up running build_py:
# distutils/build -> build_py
# distutils/install -> distutils/build ->..
# setuptools/bdist_wheel -> distutils/install ->..
# setuptools/bdist_egg -> distutils/install_lib -> build_py
# setuptools/install -> bdist_egg ->..
# setuptools/develop -> ?
# pip install:
# copies source tree to a tempdir before running egg_info/etc
# if .git isn't copied too, 'git describe' will fail
# then does setup.py bdist_wheel, or sometimes setup.py install
# setup.py egg_info -> ?
# pip install -e . and setuptools/editable_wheel will invoke build_py
# but the build_py command is not expected to copy any files.
# we override different "build_py" commands for both environments
if 'build_py' in cmds:
_build_py = cmds['build_py']
else:
from setuptools.command.build_py import build_py as _build_py
class cmd_build_py(_build_py):
def run(self):
root = get_root()
cfg = get_config_from_root(root)
versions = get_versions()
_build_py.run(self)
if getattr(self, "editable_mode", False):
# During editable installs `.py` and data files are
# not copied to build_lib
return
# now locate _version.py in the new build/ directory and replace
# it with an updated value
if cfg.versionfile_build:
target_versionfile = os.path.join(self.build_lib,
cfg.versionfile_build)
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
cmds["build_py"] = cmd_build_py
if 'build_ext' in cmds:
_build_ext = cmds['build_ext']
else:
from setuptools.command.build_ext import build_ext as _build_ext
class cmd_build_ext(_build_ext):
def run(self):
root = get_root()
cfg = get_config_from_root(root)
versions = get_versions()
_build_ext.run(self)
if self.inplace:
# build_ext --inplace will only build extensions in
# build/lib<..> dir with no _version.py to write to.
# As in place builds will already have a _version.py
# in the module dir, we do not need to write one.
return
# now locate _version.py in the new build/ directory and replace
# it with an updated value
if not cfg.versionfile_build:
return
target_versionfile = os.path.join(self.build_lib,
cfg.versionfile_build)
if not os.path.exists(target_versionfile):
print(f"Warning: {target_versionfile} does not exist, skipping "
"version update. This can happen if you are running build_ext "
"without first running build_py.")
return
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
cmds["build_ext"] = cmd_build_ext
if "cx_Freeze" in sys.modules: # cx_freeze enabled?
from cx_Freeze.dist import build_exe as _build_exe
# nczeczulin reports that py2exe won't like the pep440-style string
# as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.
# setup(console=[{
# "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION
# "product_version": versioneer.get_version(),
# ...
class cmd_build_exe(_build_exe):
def run(self):
root = get_root()
cfg = get_config_from_root(root)
versions = get_versions()
target_versionfile = cfg.versionfile_source
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
_build_exe.run(self)
os.unlink(target_versionfile)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
f.write(LONG %
{"DOLLAR": "$",
"STYLE": cfg.style,
"TAG_PREFIX": cfg.tag_prefix,
"PARENTDIR_PREFIX": cfg.parentdir_prefix,
"VERSIONFILE_SOURCE": cfg.versionfile_source,
})
cmds["build_exe"] = cmd_build_exe
del cmds["build_py"]
if 'py2exe' in sys.modules: # py2exe enabled?
try:
from py2exe.setuptools_buildexe import py2exe as _py2exe
except ImportError:
from py2exe.distutils_buildexe import py2exe as _py2exe
class cmd_py2exe(_py2exe):
def run(self):
root = get_root()
cfg = get_config_from_root(root)
versions = get_versions()
target_versionfile = cfg.versionfile_source
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
_py2exe.run(self)
os.unlink(target_versionfile)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
f.write(LONG %
{"DOLLAR": "$",
"STYLE": cfg.style,
"TAG_PREFIX": cfg.tag_prefix,
"PARENTDIR_PREFIX": cfg.parentdir_prefix,
"VERSIONFILE_SOURCE": cfg.versionfile_source,
})
cmds["py2exe"] = cmd_py2exe
# sdist farms its file list building out to egg_info
if 'egg_info' in cmds:
_egg_info = cmds['egg_info']
else:
from setuptools.command.egg_info import egg_info as _egg_info
class cmd_egg_info(_egg_info):
def find_sources(self):
# egg_info.find_sources builds the manifest list and writes it
# in one shot
super().find_sources()
# Modify the filelist and normalize it
root = get_root()
cfg = get_config_from_root(root)
self.filelist.append('versioneer.py')
if cfg.versionfile_source:
# There are rare cases where versionfile_source might not be
# included by default, so we must be explicit
self.filelist.append(cfg.versionfile_source)
self.filelist.sort()
self.filelist.remove_duplicates()
# The write method is hidden in the manifest_maker instance that
# generated the filelist and was thrown away
# We will instead replicate their final normalization (to unicode,
# and POSIX-style paths)
from setuptools import unicode_utils
normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')
for f in self.filelist.files]
manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')
with open(manifest_filename, 'w') as fobj:
fobj.write('\n'.join(normalized))
cmds['egg_info'] = cmd_egg_info
# we override different "sdist" commands for both environments
if 'sdist' in cmds:
_sdist = cmds['sdist']
else:
from setuptools.command.sdist import sdist as _sdist
class cmd_sdist(_sdist):
def run(self):
versions = get_versions()
self._versioneer_generated_versions = versions
# unless we update this, the command will keep using the old
# version
self.distribution.metadata.version = versions["version"]
return _sdist.run(self)
def make_release_tree(self, base_dir, files):
root = get_root()
cfg = get_config_from_root(root)
_sdist.make_release_tree(self, base_dir, files)
# now locate _version.py in the new base_dir directory
# (remembering that it may be a hardlink) and replace it with an
# updated value
target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile,
self._versioneer_generated_versions)
cmds["sdist"] = cmd_sdist
return cmds
CONFIG_ERROR = """
setup.cfg is missing the necessary Versioneer configuration. You need
a section like:
[versioneer]
VCS = git
style = pep440
versionfile_source = src/myproject/_version.py
versionfile_build = myproject/_version.py
tag_prefix =
parentdir_prefix = myproject-
You will also need to edit your setup.py to use the results:
import versioneer
setup(version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(), ...)
Please read the docstring in ./versioneer.py for configuration instructions,
edit setup.cfg, and re-run the installer or 'python versioneer.py setup'.
"""
SAMPLE_CONFIG = """
# See the docstring in versioneer.py for instructions. Note that you must
# re-run 'versioneer.py setup' after changing this section, and commit the
# resulting files.
[versioneer]
#VCS = git
#style = pep440
#versionfile_source =
#versionfile_build =
#tag_prefix =
#parentdir_prefix =
"""
OLD_SNIPPET = """
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
"""
INIT_PY_SNIPPET = """
from . import {0}
__version__ = {0}.get_versions()['version']
"""
def do_setup():
"""Do main VCS-independent setup function for installing Versioneer."""
root = get_root()
try:
cfg = get_config_from_root(root)
except (OSError, configparser.NoSectionError,
configparser.NoOptionError) as e:
if isinstance(e, (OSError, configparser.NoSectionError)):
print("Adding sample versioneer config to setup.cfg",
file=sys.stderr)
with open(os.path.join(root, "setup.cfg"), "a") as f:
f.write(SAMPLE_CONFIG)
print(CONFIG_ERROR, file=sys.stderr)
return 1
print(" creating %s" % cfg.versionfile_source)
with open(cfg.versionfile_source, "w") as f:
LONG = LONG_VERSION_PY[cfg.VCS]
f.write(LONG % {"DOLLAR": "$",
"STYLE": cfg.style,
"TAG_PREFIX": cfg.tag_prefix,
"PARENTDIR_PREFIX": cfg.parentdir_prefix,
"VERSIONFILE_SOURCE": cfg.versionfile_source,
})
ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
"__init__.py")
if os.path.exists(ipy):
try:
with open(ipy, "r") as f:
old = f.read()
except OSError:
old = ""
module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]
snippet = INIT_PY_SNIPPET.format(module)
if OLD_SNIPPET in old:
print(" replacing boilerplate in %s" % ipy)
with open(ipy, "w") as f:
f.write(old.replace(OLD_SNIPPET, snippet))
elif snippet not in old:
print(" appending to %s" % ipy)
with open(ipy, "a") as f:
f.write(snippet)
else:
print(" %s unmodified" % ipy)
else:
print(" %s doesn't exist, ok" % ipy)
ipy = None
# Make VCS-specific changes. For git, this means creating/changing
# .gitattributes to mark _version.py for export-subst keyword
# substitution.
do_vcs_install(cfg.versionfile_source, ipy)
return 0
def scan_setup_py():
"""Validate the contents of setup.py against Versioneer's expectations."""
found = set()
setters = False
errors = 0
with open("setup.py", "r") as f:
for line in f.readlines():
if "import versioneer" in line:
found.add("import")
if "versioneer.get_cmdclass()" in line:
found.add("cmdclass")
if "versioneer.get_version()" in line:
found.add("get_version")
if "versioneer.VCS" in line:
setters = True
if "versioneer.versionfile_source" in line:
setters = True
if len(found) != 3:
print("")
print("Your setup.py appears to be missing some important items")
print("(but I might be wrong). Please make sure it has something")
print("roughly like the following:")
print("")
print(" import versioneer")
print(" setup( version=versioneer.get_version(),")
print(" cmdclass=versioneer.get_cmdclass(), ...)")
print("")
errors += 1
if setters:
print("You should remove lines like 'versioneer.VCS = ' and")
print("'versioneer.versionfile_source = ' . This configuration")
print("now lives in setup.cfg, and should be removed from setup.py")
print("")
errors += 1
return errors
def setup_command():
"""Set up Versioneer and exit with appropriate error code."""
errors = do_setup()
errors += scan_setup_py()
sys.exit(1 if errors else 0)
if __name__ == "__main__":
cmd = sys.argv[1]
if cmd == "setup":
setup_command()