Commit 5e31fa1f authored by mashun1's avatar mashun1
Browse files

jax-cfd

parents
Pipeline #1015 canceled with stages
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.diffusion."""
from absl.testing import absltest
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.collocated import diffusion
class DiffusionTest(test_util.TestCase):
"""Some simple sanity tests for diffusion on constant fields."""
def test_explicit_diffusion(self):
nu = 1.
shape = (101, 101, 101)
offset = (0.5, 0.5, 0.5)
step = (1., 1., 1.)
grid = grids.Grid(shape, step)
c = grids.GridVariable(
array=grids.GridArray(jnp.ones(shape), offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
diffused = diffusion.diffuse(c, nu)
expected = grids.GridArray(jnp.zeros_like(diffused.data), offset, grid)
self.assertAllClose(expected, diffused)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples of defining equations."""
from typing import Callable, Optional
import jax
from jax_cfd.base import advection as base_advection
from jax_cfd.base import grids
from jax_cfd.collocated import diffusion
from jax_cfd.collocated import pressure
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
ConvectFn = Callable[[GridVariableVector], GridArrayVector]
DiffuseFn = Callable[[GridVariable, float], GridArray]
ForcingFn = Callable[[GridVariableVector], GridArrayVector]
def sum_fields(*args):
return jax.tree_util.tree_map(lambda *a: sum(a), *args)
def semi_implicit_navier_stokes(
density: float,
viscosity: float,
dt: float,
grid: grids.Grid,
convect: Optional[ConvectFn] = None,
diffuse: DiffuseFn = diffusion.diffuse,
pressure_solve: Callable = pressure.solve_cg,
forcing: Optional[ForcingFn] = None,
) -> Callable[[GridVariableVector], GridVariableVector]:
"""Returns a function that performs a time step of Navier Stokes."""
del grid # unused
if convect is None:
def convect(v): # pylint: disable=function-redefined
return tuple(
base_advection.advect_van_leer_using_limiters(u, v, dt) for u in v)
convect = jax.named_call(convect, name='convection')
diffuse = jax.named_call(diffuse, name='diffusion')
pressure_projection = jax.named_call(
pressure.projection, name='pressure')
@jax.named_call
def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
"""Computes state at time `t + dt` using first order time integration."""
# Collect the acceleration terms
convection = convect(v)
accelerations = [convection]
if viscosity is not None:
diffusion_ = tuple(diffuse(u, viscosity / density) for u in v)
accelerations.append(diffusion_)
if forcing is not None:
# TODO(shoyer): include time in state?
force = forcing(v)
accelerations.append(tuple(f / density for f in force))
dvdt = sum_fields(*accelerations)
# Update v by taking a time step
v = tuple(
grids.GridVariable(u.array + dudt * dt, u.bc)
for u, dudt in zip(v, dvdt))
# Pressure projection to incompressible velocity field
v = pressure_projection(v, pressure_solve)
return v
return navier_stokes_step
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.equations."""
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.collocated import advection
from jax_cfd.collocated import equations
from jax_cfd.collocated import pressure
import numpy as np
def zero_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns an all-zero periodic velocity fields."""
offset = grid.cell_center
data = jnp.zeros(grid.shape)
return tuple(
grids.GridVariable(grids.GridArray(data, offset, grid),
boundaries.periodic_boundary_conditions(grid.ndim))
for _ in range(grid.ndim))
def sinusoidal_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns a divergence-free velocity flow on `grid`."""
offset = grid.cell_center
mesh = grid.mesh(offset)
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
vs = tuple(
jnp.sin(2. * np.pi * g / s) for g, s in zip(mesh, mesh_size))
return tuple(
grids.GridVariable(grids.GridArray(v, offset, grid),
boundaries.periodic_boundary_conditions(grid.ndim))
for v in vs[1:] + vs[:1])
def gaussian_force_field(grid) -> grids.GridArrayVector:
"""Returns a 'Gaussian-shaped' force field in the 'x' direction."""
offset = grid.cell_center
mesh = grid.mesh(offset)
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
data = jnp.exp(-sum([jnp.square(x / s - .5)
for x, s in zip(mesh, mesh_size)]) * 100.)
v = [grids.GridArray(data, offset, grid)]
for _ in range(1, grid.ndim):
v.append(grids.GridArray(jnp.zeros(grid.shape), offset, grid))
return tuple(v)
def gaussian_forcing(v: grids.GridVariableVector) -> grids.GridArrayVector:
"""Returns Gaussian field forcing."""
grid = grids.consistent_grid(*v)
return gaussian_force_field(grid)
def momentum(v: grids.GridVariableVector, density: float) -> grids.Array:
"""Returns the momentum due to velocity field `v`."""
grid = grids.consistent_grid(*v)
return jnp.array([u.data for u in v]).sum() * density * jnp.array(
grid.step).prod()
def _convect_upwind(v: grids.GridVariableVector) -> grids.GridArrayVector:
return tuple(advection.advect_linear(u, v) for u in v)
class SemiImplicitNavierStokesTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='semi_implicit_sinusoidal_velocity_base',
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=None,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=5e-3),
dict(testcase_name='semi_implicit_gaussian_force_upwind',
velocity=zero_velocity_field,
forcing=gaussian_forcing,
shape=(40, 40, 40),
step=(1., 1., 1.),
density=1.,
viscosity=None,
convect=_convect_upwind,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=100,
divergence_atol=1e-4,
momentum_atol=2e-4),
)
def test_divergence_and_momentum(
self, velocity, forcing, shape, step, density, viscosity, convect,
pressure_solve, dt, time_steps, divergence_atol, momentum_atol,
):
grid = grids.Grid(shape, step)
navier_stokes = equations.semi_implicit_navier_stokes(
density,
viscosity,
dt,
grid,
convect=convect,
pressure_solve=pressure_solve,
forcing=forcing)
v_initial = velocity(grid)
v_final = funcutils.repeated(navier_stokes, time_steps)(v_initial)
divergence = fd.centered_divergence(v_final)
self.assertLess(jnp.max(divergence.data), divergence_atol)
initial_momentum = momentum(v_initial, density)
final_momentum = momentum(v_final, density)
if forcing is not None:
expected_change = (
jnp.array([f.data for f in forcing(v_initial)]).sum() *
jnp.array(grid.step).prod() * dt * time_steps)
else:
expected_change = 0
self.assertAllClose(
initial_momentum + expected_change, final_momentum, atol=momentum_atol)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare initial conditions for simulations."""
from typing import Union
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import filter_utils
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.collocated import pressure
import numpy as np
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
Array = Union[np.ndarray, jax.Array]
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions
def _log_normal_pdf(x, mode, variance=.25):
"""Unscaled PDF for a log normal given `mode` and log variance 1."""
mean = jnp.log(mode) + variance
logx = jnp.log(x)
return jnp.exp(-(mean - logx)**2 / 2 / variance - logx)
def _max_speed(v):
return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()
def filtered_velocity_field(
rng_key: grids.Array,
grid: grids.Grid,
maximum_velocity: float = 1,
peak_wavenumber: float = 3,
iterations: int = 3,
) -> GridVariableVector:
"""Create divergence-free velocity fields with appropriate spectral filtering.
Modified version for collocated variables.
Args:
rng_key: key for seeding the random initial velocity field.
grid: the grid on which the velocity field is defined.
maximum_velocity: the maximum speed in the velocity field.
peak_wavenumber: the velocity field will be filtered so that the largest
magnitudes are associated with this wavenumber.
iterations: the number of repeated pressure projection and normalization
iterations to apply.
Returns:
A divergence free velocity field with the given maximum velocity. Associates
periodic boundary conditions with the velocity field components.
"""
# Log normal distribution peaked at `peak_wavenumber`. Note that we have to
# divide by `k ** (ndim - 1)` to account for the volume of the
# `ndim - 1`-sphere of values with wavenumber `k`.
def spectral_density(k):
return _log_normal_pdf(k, peak_wavenumber) / k ** (grid.ndim - 1)
keys = jax.random.split(rng_key, grid.ndim)
velocity_components = []
boundary_conditions = []
for k in keys:
noise = jax.random.normal(k, grid.shape)
velocity_components.append(
filter_utils.filter(spectral_density, noise, grid))
boundary_conditions.append(
boundaries.periodic_boundary_conditions(grid.ndim))
# Place values on cell-centered grid
velocity = tuple(
grids.GridVariable(grids.GridArray(u, grid.cell_center, grid), bc)
for u, bc in zip(velocity_components, boundary_conditions))
def project_and_normalize(v: GridVariableVector):
v = pressure.projection(v)
vmax = _max_speed(v)
v = tuple(
grids.GridVariable(maximum_velocity * u.array / vmax, u.bc) for u in v)
return v
# Due to numerical precision issues, we repeatedly normalize and project the
# velocity field. This ensures that it is divergence-free and achieves the
# specified maximum velocity.
return funcutils.repeated(project_and_normalize, iterations)(velocity)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.initial_conditions."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.collocated import initial_conditions
import numpy as np
def get_grid(grid_size, ndim, domain_size_multiple=1):
domain = ((0, 2 * np.pi * domain_size_multiple),) * ndim
shape = (grid_size,) * ndim
return grids.Grid(shape=shape, domain=domain)
class InitialConditionsTest(test_util.TestCase):
@parameterized.parameters(
dict(seed=3232,
grid=get_grid(128, ndim=3),
maximum_velocity=1.,
peak_wavenumber=2),
)
def test_filtered_velocity_field(
self, seed, grid, maximum_velocity, peak_wavenumber):
v = initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber)
actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()
max_divergence = fd.centered_divergence(v).data.max()
# Assert that initial velocity is divergence free
self.assertAllClose(0., max_divergence, atol=1e-4)
# Assert that the specified maximum velocity is obtained.
self.assertAllClose(maximum_velocity, actual_maximum_velocity, atol=1e-4)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for computing and applying pressure."""
from typing import Callable, Optional
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# TODO(pnorgaard) Implement bicgstab for non-symmetric operators
def solve_cg(
v: GridVariableVector,
q0: GridVariable,
rtol: float = 1e-6,
atol: float = 1e-6,
maxiter: Optional[int] = None) -> GridArray:
"""Conjugate gradient solve for the pressure such that continuity is enforced.
Returns a pressure correction `q` such that `div(v - grad(q)) == 0`.
The relationship between `q` and our actual pressure estimate is given by
`p = q * density / dt`.
Args:
v: the velocity field.
q0: an initial value, or "guess" for the pressure correction. A common
choice is the correction from the previous time step. Also specifies the
boundary conditions on `q`.
rtol: relative tolerance for convergence.
atol: absolute tolerance for convergence.
maxiter: optional int, the maximum number of iterations to perform.
Returns:
A pressure correction `q` such that `div(v - grad(q))` is zero.
"""
rhs = fd.centered_divergence(v)
def laplacian_with_bcs(array: GridArray) -> GridArray:
if not boundaries.has_all_periodic_boundary_conditions(q0):
raise ValueError(
'Laplacian operator implementation requires periodic bc.')
variable = grids.GridVariable(array, q0.bc)
gradient = fd.central_difference(variable, axis=None)
gradient = tuple(grids.GridVariable(g, q0.bc) for g in gradient)
return fd.centered_divergence(gradient)
q, _ = jax.scipy.sparse.linalg.cg(
laplacian_with_bcs,
rhs,
x0=q0.array,
tol=rtol,
atol=atol,
maxiter=maxiter)
return q
def projection(
v: GridVariableVector,
solve: Callable = solve_cg,
) -> GridVariableVector:
"""Apply pressure projection to make a velocity field divergence free."""
grid = grids.consistent_grid(*v)
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
q0 = grids.GridVariable(
grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid),
pressure_bc)
q = solve(v, q0)
q = grids.GridVariable(q, pressure_bc)
q_grad = fd.central_difference(q, axis=None)
v_projected = tuple(
grids.GridVariable(u.array - q_g, u.bc) for u, q_g in zip(v, q_grad))
return v_projected
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.pressure."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.collocated import pressure
import numpy as np
USE_FLOAT64 = True
solve_cg = functools.partial(pressure.solve_cg, atol=1e-6, maxiter=10**5)
class PressureTest(test_util.TestCase):
def setUp(self):
jax.config.update('jax_enable_x64', USE_FLOAT64)
super(PressureTest, self).setUp()
@parameterized.named_parameters(
dict(testcase_name='_1D_cg',
shape=(301,),
solve=solve_cg,
step=(.1,),
seed=111),
dict(testcase_name='_2D_cg',
shape=(100, 100),
solve=solve_cg,
step=(1., 1.),
seed=222),
dict(testcase_name='_3D_cg',
shape=(10, 10, 10),
solve=solve_cg,
step=(.1, .1, .1),
seed=333),
)
def test_pressure_correction_periodic_bc(
self, shape, solve, step, seed):
"""Returned velocity should be divergence free."""
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
# The uncorrected velocity is a 1 + a small amount of noise in each
# dimension.
ks = jax.random.split(jax.random.PRNGKey(seed), 2 * len(shape))
offset = grid.cell_center
v = tuple(
grids.GridArray(1. + .3 * jax.random.normal(k, shape), offset, grid)
for k in ks[:grid.ndim])
v = tuple(grids.GridVariable(u, bc) for u in v)
v_corrected = pressure.projection(v, solve)
# The corrected velocity should be divergence free.
div = fd.centered_divergence(v_corrected)
for u, u_corrected in zip(v, v_corrected):
np.testing.assert_allclose(u.offset, u_corrected.offset)
np.testing.assert_allclose(div.data, 0., atol=1e-4)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with generated datasets from JAX-CFD."""
import jax_cfd.data.evaluation
import jax_cfd.data.visualization
import jax_cfd.data.xarray_utils
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods for evaluation of trained models."""
from typing import Sequence, Tuple
import jax
import jax.numpy as jnp
from jax_cfd.data import xarray_utils as xr_utils
import numpy as np
import xarray
# pytype complains about valid operations with xarray (e.g., see b/153704639),
# so it isn't worth the trouble of running it.
# pytype: skip-file
def absolute_error(
array: xarray.DataArray,
eval_model_name: str = 'learned',
target_model_name: str = 'ground_truth',
) -> xarray.DataArray:
"""Computes absolute error between to be evaluated and target models.
Args:
array: xarray.DataArray that contains model dimension with `eval_model_name`
and `target_model_name` coordinates.
eval_model_name: name of the model that is being evaluated.
target_model_name: name of the model representing the ground truth values.
Returns:
xarray.DataArray containing absolute value of errors between
`eval_model_name` and `target_model_name` models.
"""
predicted = array.sel(model=eval_model_name)
target = array.sel(model=target_model_name)
return abs(predicted - target).rename('_'.join([predicted.name, 'error']))
def state_correlation(
array: xarray.DataArray,
eval_model_name: str = 'learned',
target_model_name: str = 'ground_truth',
non_state_dims: Tuple[str, ...] = (xr_utils.XR_SAMPLE_NAME,
xr_utils.XR_TIME_NAME),
non_state_dims_to_average: Tuple[str, ...] = (xr_utils.XR_SAMPLE_NAME,),
) -> xarray.DataArray:
"""Computes normalized correlation of `array` between target and eval models.
The dimensions of the `array` are expected to consists of state dimensions
that are interpreted as a vector parametrizing the configuration of the system
and `non_state_dims`, that optionally are averaged over if included in
`non_state_dims_to_average`.
Args:
array: xarray.DataArray that contains model dimension with `eval_model_name`
and `target_model_name` coordinates.
eval_model_name: name of the model that is being evaluated.
target_model_name: name of the model representing the ground truth values.
non_state_dims: tuple of dimension names that are not a part of the state.
non_state_dims_to_average: tuple of `non_state_dims` to average over.
Returns:
xarray.DataArray containing normalized correlation between `eval_model_name`
and `target_model_name` models.
"""
predicted = array.sel(model=eval_model_name)
target = array.sel(model=target_model_name)
state_dims = list(set(predicted.dims) - set(non_state_dims))
predicted = xr_utils.normalize(predicted, state_dims)
target = xr_utils.normalize(target, state_dims)
result = (predicted * target).sum(state_dims).mean(non_state_dims_to_average)
return result.rename('_'.join([array.name, 'correlation']))
def approximate_quantiles(ds, quantile_thresholds):
"""Approximate quantiles of all arrays in the given xarray.Dataset."""
# quantiles can't be done in a blocked fashion in the current version of dask,
# so for now select only the first time step and create a single chunk for
# each array.
return ds.isel(time=0).chunk(-1).quantile(q=quantile_thresholds)
def below_error_threshold(
array: xarray.DataArray,
threshold: xarray.DataArray,
eval_model_name: str = 'learned',
target_model_name: str = 'ground_truth',
) -> xarray.DataArray:
"""Compute if eval model error is below a threshold based on the target."""
predicted = array.sel(model=eval_model_name)
target = array.sel(model=target_model_name)
return abs(predicted - target) <= threshold
def average(
array: xarray.DataArray,
ndim: int,
non_spatial_dims: Tuple[str, ...] = (xr_utils.XR_SAMPLE_NAME,)
) -> xarray.DataArray:
"""Computes spatial and `non_spatial_dims` mean over `array`.
Since complex values are not supported in netcdf format we currently check if
imaginary part can be discarded, otherwise an error is raised.
Args:
array: xarray.DataArray to take a mean of. Expected to have `ndim` spatial
dimensions with names as in `xr_utils.XR_SPATIAL_DIMS`.
ndim: number of spatial dimensions.
non_spatial_dims: tuple of dimension names to average besides space.
Returns:
xarray.DataArray with `ndim` spatial dimensions and `non_spatial_dims`
reduced to mean values.
Raises:
ValueError: if `array` contains non-real imaginary values.
"""
dims = list(non_spatial_dims) + list(xr_utils.XR_SPATIAL_DIMS[:ndim])
dims = [dim for dim in dims if dim in array.dims]
mean_values = array.mean(dims)
if np.iscomplexobj(mean_values):
raise ValueError('complex values are not supported.')
return mean_values
def energy_spectrum_metric(threshold=0.01):
"""Computes an energy spectrum metric that checks if a simulation failed."""
@jax.jit
def _energy_spectrum_metric(arr, ground_truth):
diff = jnp.abs(jnp.log(arr) - jnp.log(ground_truth))
metric = jnp.sum(jnp.where(ground_truth > threshold, diff, 0), axis=-1)
cutoff = jnp.sum(
jnp.where((arr > threshold) & (ground_truth < threshold),
jnp.abs(jnp.log(arr)), 0),
axis=-1)
return metric + cutoff
energy_spectrum_ds = lambda a, b: xarray.apply_ufunc( # pylint: disable=g-long-lambda
_energy_spectrum_metric, a, b, input_core_dims=[['kx'], ['kx']]).mean(
dim='sample')
return energy_spectrum_ds
def u_x_correlation_metric(threshold=0.5):
"""Computes a spacial spectrum metric that checks if a simulation failed."""
@jax.jit
def _u_x_correlation_metric(arr, ground_truth):
diff = (jnp.abs(arr - ground_truth))
metric = jnp.sum(
jnp.where(jnp.abs(ground_truth) > threshold, diff, 0), axis=-1)
cutoff = jnp.sum(
jnp.where(
(jnp.abs(arr) > threshold) & (jnp.abs(ground_truth) < threshold),
jnp.abs(arr), 0),
axis=-1)
return metric + cutoff
u_x_correlation_ds = lambda a, b: xarray.apply_ufunc( # pylint: disable=g-long-lambda
_u_x_correlation_metric, a, b, input_core_dims=[['dx'], ['dx']]).mean(
dim='sample')
return u_x_correlation_ds
def temporal_autocorrelation(array):
"""Computes temporal autocorrelation of array."""
dt = array['time'][1] - array['time'][0]
length = array.sizes['time']
subsample = max(1, int(1. / dt))
def _autocorrelation(array):
def _corr(x, d):
del x
arr1 = jnp.roll(array, d, 0)
ans = arr1 * array
ans = jnp.sum(
jnp.where(
jnp.arange(length).reshape(-1, 1, 1, 1) >= d, ans / length, 0),
axis=0)
return d, ans
_, full_result = jax.lax.scan(_corr, 0, jnp.arange(0, length, subsample))
return full_result
full_result = jax.jit(_autocorrelation)(
jnp.array(array.transpose('time', 'sample', 'x', 'model').u))
full_result = xarray.Dataset(
data_vars=dict(t_corr=(['time', 'sample', 'x', 'model'], full_result)),
coords={
'dt': np.array(array.time[slice(None, None, subsample)]),
'sample': array.sample,
'x': array.x,
'model': array.model
})
return full_result
def u_t_correlation_metric(threshold=0.5):
"""Computes a temporal spectrum metric that checks if a simulation failed."""
@jax.jit
def _u_t_correlation_metric(arr, ground_truth):
diff = (jnp.abs(arr - ground_truth))
metric = jnp.sum(
jnp.where(jnp.abs(ground_truth) > threshold, diff, 0), axis=-1)
cutoff = jnp.sum(
jnp.where(
(jnp.abs(arr) > threshold) & (jnp.abs(ground_truth) < threshold),
jnp.abs(arr), 0),
axis=-1)
return jnp.mean(metric + cutoff)
return _u_t_correlation_metric
def compute_summary_dataset(
model_ds: xarray.Dataset,
ground_truth_ds: xarray.Dataset,
quantile_thresholds: Sequence[float] = (0.1, 1.0),
non_spatial_dims: Tuple[str, ...] = (xr_utils.XR_SAMPLE_NAME,)
) -> xarray.Dataset:
"""Computes sample and space averaged summaries of predictions and errors.
Args:
model_ds: dataset containing trajectories unrolled using the model.
ground_truth_ds: dataset containing ground truth trajectories.
quantile_thresholds: quantile thresholds to use for "within error" metrics.
non_spatial_dims: tuple of dimension names to average besides space.
Returns:
xarray.Dataset containing observables and absolute value errors
averaged over sample and spatial dimensions.
"""
ndim = ground_truth_ds.attrs['ndim']
eval_model_name = 'eval_model'
target_model_name = 'ground_truth'
combined_dataset = xarray.concat([model_ds, ground_truth_ds], dim='model')
combined_dataset.coords['model'] = [eval_model_name, target_model_name]
combined_dataset = combined_dataset.sel(time=slice(None, 500))
summaries = [combined_dataset[u] for u in xr_utils.XR_VELOCITY_NAMES[:ndim]]
spectrum = xr_utils.energy_spectrum(combined_dataset).rename(
'energy_spectrum')
summaries += [
xr_utils.kinetic_energy(combined_dataset),
xr_utils.speed(combined_dataset),
spectrum,
]
# TODO(dkochkov) Check correlations in NS and enable it for 2d and 3d.
if ndim == 1:
correlations = xr_utils.velocity_spatial_correlation(combined_dataset, 'x')
time_correlations = temporal_autocorrelation(combined_dataset)
summaries += [correlations[variable] for variable in correlations]
u_x_corr_sum = [
xarray.DataArray((u_x_correlation_metric(threshold)( # pylint: disable=g-complex-comprehension
correlations.sel(model=eval_model_name),
correlations.sel(model=target_model_name))).u_x_correlation)
for threshold in [0.5]
]
if not time_correlations.t_corr.isnull().any():
# autocorrelation is a constant, so it is expanded to be part of summaries
u_t_corr_sum = [
xarray.ones_like(u_x_corr_sum[0]).rename('autocorrelation') * # pylint: disable=g-complex-comprehension
u_t_correlation_metric(threshold)(
jnp.array(time_correlations.t_corr.sel(model=eval_model_name)),
jnp.array(time_correlations.t_corr.sel(model=target_model_name)))
for threshold in [0.5]
]
else:
# if the trajectory goes to nan, it just reports a large number
u_t_corr_sum = [
xarray.ones_like(u_x_corr_sum[0]).rename('autocorrelation') * np.infty
for threshold in [0.5]
]
energy_sum = [
energy_spectrum_metric(threshold)( # pylint: disable=g-complex-comprehension
spectrum.sel(model=eval_model_name, kx=slice(0, spectrum.kx.max())),
spectrum.sel(
model=target_model_name,
kx=slice(0, spectrum.kx.max()))).rename('energy_spectrum_%f' %
threshold)
for threshold in [0.001, 0.01, 0.1, 1.0, 10]
] # pylint: disable=g-complex-comprehension
custom_summaries = u_x_corr_sum + energy_sum + u_t_corr_sum
if ndim == 2:
summaries += [
xr_utils.enstrophy_2d(combined_dataset),
xr_utils.vorticity_2d(combined_dataset),
xr_utils.isotropic_energy_spectrum(
combined_dataset,
average_dims=non_spatial_dims).rename('energy_spectrum')
]
if ndim >= 2:
custom_summaries = []
mean_summaries = [
average(s.sel(model=eval_model_name), ndim).rename(s.name + '_mean')
for s in summaries
]
error_summaries = [
average(absolute_error(s, eval_model_name, target_model_name), ndim)
for s in summaries
]
correlation_summaries = [
state_correlation(s, eval_model_name, target_model_name)
for s in summaries
if s.name in xr_utils.XR_VELOCITY_NAMES + ('vorticity',)
]
summaries_ds = xarray.Dataset({array.name: array for array in summaries})
thresholds = approximate_quantiles(
summaries_ds, quantile_thresholds=quantile_thresholds).compute()
threshold_summaries = []
for threshold_quantile in quantile_thresholds:
for summary in summaries:
name = summary.name
error_threshold = thresholds[name].sel(
quantile=threshold_quantile, drop=True)
below_error = below_error_threshold(summary, error_threshold,
eval_model_name, target_model_name)
below_error.name = f'{name}_within_q={threshold_quantile}'
threshold_summaries.append(average(below_error, ndim))
all_summaries = (
mean_summaries + error_summaries + threshold_summaries +
correlation_summaries + custom_summaries)
return xarray.Dataset({array.name: array for array in all_summaries})
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Visualization utilities."""
from typing import Any, BinaryIO, Callable, Optional, List, Tuple, Union
from jax_cfd.base import grids
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import PIL.Image as Image
import seaborn as sns
NormFn = Callable[[grids.Array, int], mpl.colors.Normalize]
def quantile_normalize_fn(
image_data: grids.Array,
image_num: int,
quantile: float = 0.999
) -> mpl.colors.Normalize:
"""Returns `mpl.colors.Normalize` object that range defined by data quantile.
Args:
image_data: data for which `Normalize` object is produced.
image_num: number of frame in the series. Not used.
quantile: quantile that should be included in the range.
Returns:
`mpl.colors.Normalize` that covers the range of values that include quantile
of `image_data` values.
"""
del image_num # not used by `quantile_normalize_fn`.
max_to_include = np.quantile(abs(image_data), quantile)
norm = mpl.colors.Normalize(vmin=-max_to_include, vmax=max_to_include)
return norm
def resize_image(
image: Image.Image,
longest_side: int,
resample: int = Image.Resampling.NEAREST,
) -> Image.Image:
"""Resize an image, preserving its aspect ratio."""
resize_factor = longest_side / max(image.size)
new_size = tuple(round(s * resize_factor) for s in image.size)
return image.resize(new_size, resample)
def trajectory_to_images(
trajectory: grids.Array,
compute_norm_fn: NormFn = quantile_normalize_fn,
cmap: mpl.colors.ListedColormap = sns.cm.icefire, # pytype: disable=module-attr
longest_side: Optional[int] = None,
) -> List[Image.Image]:
"""Converts scalar trajectory with leading time axis into a list of images."""
images = []
for i, image_data in enumerate(trajectory):
norm = compute_norm_fn(image_data, i)
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
img = Image.fromarray(mappable.to_rgba(image_data, bytes=True))
if longest_side is not None:
img = resize_image(img, longest_side)
images.append(img)
return images
# TODO(dkochkov) consider generalizing this to a general facet.
def horizontal_facet(
separate_images: List[List[Image.Image]],
relative_separation_width: float,
separation_rgb: Tuple[int, int, int] = (255, 255, 255)
) -> List[Image.Image]:
"""Stitches separate images into a single one with a separation strip.
Args:
separate_images: lists of images each representing time series. All images
must have the same size.
relative_separation_width: width of the separation defined as a fraction of
a separate image.
separation_rgb: rgb color code of the separation strip to add between
adjacent images.
Returns:
list of merged images that contain images passed as `separate_images` with
a separating strip.
"""
images = []
for frames in zip(*separate_images):
images_to_combine = len(frames)
separation_width = round(frames[0].width * relative_separation_width)
image_height = frames[0].height
image_width = (frames[0].width * images_to_combine +
separation_width * (images_to_combine - 1))
full_im = Image.new('RGB', (image_width, image_height))
sep_im = Image.new('RGB', (separation_width, image_height), separation_rgb)
full_im = Image.new('RGB', (image_width, image_height))
width_offset = 0
height_offset = 0
for frame in frames:
full_im.paste(frame, (width_offset, height_offset))
width_offset += frame.width
if width_offset < full_im.width:
full_im.paste(sep_im, (width_offset, height_offset))
width_offset += sep_im.width
images.append(full_im)
return images
def save_movie(
images: List[Image.Image],
output_path: Union[str, BinaryIO],
duration: float = 150.,
loop: int = 0,
**kwargs: Any
):
"""Saves `images` as a movie of duration `duration` to `output_path`.
Args:
images: list of images representing time series that will be saved as movie.
output_path: file handle or cns path to where save the movie.
duration: duration of the movie in milliseconds.
loop: number of times to loop the movie. 0 interpreted as indefinite.
**kwargs: additional keyword arguments to be passed to `Image.save`.
"""
images[0].save(output_path, save_all=True, append_images=images[1:],
duration=duration, loop=loop, **kwargs)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.data.visualization."""
import os.path
from absl.testing import absltest
from jax_cfd.base import test_util
from jax_cfd.data import visualization
import numpy as np
class VisualizationTest(test_util.TestCase):
def test_trajectory_to_images_shape(self):
"""Tests that trajectory_to_images generates a list of images."""
trajectory = np.random.uniform(size=(25, 32, 48))
list_of_images = visualization.trajectory_to_images(trajectory)
self.assertEqual(len(list_of_images), trajectory.shape[0])
self.assertEqual(list_of_images[0].size, (48, 32))
list_of_images = visualization.trajectory_to_images(
trajectory, longest_side=96)
self.assertEqual(len(list_of_images), trajectory.shape[0])
self.assertEqual(list_of_images[0].size, (96, 64))
def test_horizontal_facet_shape(self):
"""Tests that horizontal_facet generates images of expected size."""
trajectory_a = np.random.uniform(size=(25, 32, 32))
trajectory_b = np.random.uniform(size=(25, 32, 32))
relative_separation_width = 0.25
list_of_images_a = visualization.trajectory_to_images(trajectory_a)
list_of_images_b = visualization.trajectory_to_images(trajectory_b)
list_of_images_facet = visualization.horizontal_facet(
[list_of_images_a, list_of_images_b], relative_separation_width)
actual_width = list_of_images_facet[0].width
expected_width = 32 * 2 + int(32 * relative_separation_width)
self.assertEqual(actual_width, expected_width)
def test_save_movie_local(self):
"""Tests that save_movie write gif to a file."""
temp_dir = self.create_tempdir()
temp_filename = os.path.join(temp_dir, 'tmp_file.gif')
input_trajectory = np.random.uniform(size=(25, 32, 32))
images = visualization.trajectory_to_images(input_trajectory)
visualization.save_movie(images, temp_filename)
self.assertTrue(os.path.exists(temp_filename))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for xarray datasets, naming and metadata.
Note on metadata conventions:
When we store data onto xarray.Dataset objects, we are (currently) a little
sloppy about coordinate metadata: we store only a single array for each set of
coordinate values, even though components of our velocity fields are typically
staggered. This is convenient for quick-and-dirty analytics, but means that
variables at the "same" coordinates location may actually be dislocated by any
offset within the unit cell.
"""
import functools
from typing import Any, Dict, Mapping, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import grids
import numpy as np
import pandas
import xarray
Array = grids.Array
GridArray = grids.GridArray
GridVariable = grids.GridVariable
# pytype complains about valid operations with xarray (e.g., see b/153704639),
# so it isn't worth the trouble of running it.
# pytype: skip-file
#
# xarray `Dataset` names for coordinates and attributes.
#
XR_VELOCITY_NAMES = ('u', 'v', 'w')
XR_SCALAR_NAMES = ('c')
XR_SPATIAL_DIMS = ('x', 'y', 'z')
XR_WAVENUMBER_DIMS = ('kx', 'ky', 'kz')
XR_SAMPLE_NAME = 'sample'
XR_TIME_NAME = 'time'
XR_OFFSET_NAME = 'offset'
XR_SAVE_GRID_SIZE_ATTR_NAME = 'save_grid_size'
XR_SAVE_GRID_SIZE_ATTR_NAME_RECTANGLE = ('save_grid_size_x', 'save_grid_size_y')
XR_DOMAIN_SIZE_NAME = 'domain_size'
XR_NDIM_ATTR_NAME = 'ndim'
XR_STABLE_TIME_STEP_ATTR_NAME = 'stable_time_step'
def velocity_trajectory_to_xarray(
trajectory: Tuple[Union[Array, GridArray, GridVariable], ...],
grid: grids.Grid = None,
time: np.ndarray = None,
attrs: Dict[str, Any] = None,
samples: bool = False,
prefix_name: str = '',
) -> xarray.Dataset:
"""Convert a trajectory of velocities to an xarray `Dataset`."""
dimension = len(trajectory)
if grid is not None:
dimension = grid.ndim
dims = (((XR_SAMPLE_NAME,) if samples else ())
+ (XR_TIME_NAME,)
+ XR_SPATIAL_DIMS[:dimension])
data_vars = {}
num_scalars = len(trajectory) - dimension
for component in range(num_scalars):
name = XR_SCALAR_NAMES[component]
data = trajectory[component]
var_attrs = {}
if isinstance(data, GridArray) or isinstance(data, GridVariable):
var_attrs[XR_OFFSET_NAME] = data.offset
data = data.data
data_vars[prefix_name + name] = xarray.Variable(dims, data, var_attrs)
for component in range(dimension):
name = XR_VELOCITY_NAMES[component]
data = trajectory[component + num_scalars]
if isinstance(data, GridArray) or isinstance(data, GridVariable):
data = data.data
var_attrs = {}
if grid is not None:
var_attrs[XR_OFFSET_NAME] = grid.cell_faces[component]
data_vars[prefix_name + name] = xarray.Variable(dims, data, var_attrs)
if samples:
num_samples = next(iter(data_vars.values())).shape[0]
sample_ids = np.arange(num_samples)
else:
sample_ids = None
coords = construct_coords(grid, time, sample_ids)
return xarray.Dataset(data_vars, coords, attrs)
def construct_coords(
grid: Optional[grids.Grid] = None,
times: Optional[np.ndarray] = None,
sample_ids: Optional[np.ndarray] = None,
) -> Mapping[str, np.ndarray]:
"""Create coordinate arrays."""
coords = {}
if grid is not None:
axes = grid.axes(grid.cell_center)
coords.update({dim: ax for dim, ax in zip(XR_SPATIAL_DIMS, axes)})
if times is not None:
coords[XR_TIME_NAME] = times
if sample_ids is not None:
coords[XR_SAMPLE_NAME] = sample_ids
return coords
def grid_from_attrs(dataset_attrs) -> grids.Grid:
"""Constructs a `Grid` object from dataset attributes."""
ndim = dataset_attrs[XR_NDIM_ATTR_NAME]
if XR_SAVE_GRID_SIZE_ATTR_NAME in dataset_attrs:
grid_size = dataset_attrs[XR_SAVE_GRID_SIZE_ATTR_NAME]
grid_shape = (grid_size,) * ndim
if XR_DOMAIN_SIZE_NAME in dataset_attrs:
domain_size = dataset_attrs[XR_DOMAIN_SIZE_NAME]
elif 'domain_size_multiple' in dataset_attrs:
# TODO(shoyer): remove this legacy case, once we no longer use datasets
# generated prior to 2020-09-18
domain_size = 2 * np.pi * dataset_attrs['domain_size_multiple']
else:
raise ValueError(
f'could not figure out domain size from attrs:\n{dataset_attrs}')
grid_domain = [(0, domain_size)] * ndim
else:
grid_shape = tuple(dataset_attrs[attr]
for attr in XR_SAVE_GRID_SIZE_ATTR_NAME_RECTANGLE[:ndim])
aspect_ratio = dataset_attrs['aspect_ratio']
domain_z = (0, 1)
domain_x = (0, aspect_ratio)
grid_domain = (domain_x, domain_z)
grid = grids.Grid(grid_shape, domain=grid_domain)
return grid
def vorticity_2d(ds: xarray.Dataset) -> xarray.DataArray:
"""Calculate vorticity on a 2D dataset."""
# Vorticity is calculated from staggered velocities at offset=(1, 1).
dy = ds.y[1] - ds.y[0]
dx = ds.x[1] - ds.x[0]
dv_dx = (ds.v.roll(x=-1, roll_coords=False) - ds.v) / dx
du_dy = (ds.u.roll(y=-1, roll_coords=False) - ds.u) / dy
return (dv_dx - du_dy).rename('vorticity')
def enstrophy_2d(ds: xarray.Dataset) -> xarray.DataArray:
"""Calculate entrosphy over a 2D dataset."""
return (vorticity_2d(ds) ** 2 / 2).rename('enstrophy')
def magnitude(
u: xarray.DataArray,
v: Optional[xarray.DataArray] = None,
w: Optional[xarray.DataArray] = None,
) -> xarray.DataArray:
"""Calculate the magnitude of a velocity field."""
total = sum((c * c.conj()).real for c in [u, v, w] if c is not None)
return total ** 0.5
def speed(ds: xarray.Dataset) -> xarray.DataArray:
"""Calculate speed at each point in a velocity field."""
args = [ds[k] for k in XR_VELOCITY_NAMES if k in ds]
return magnitude(*args).rename('speed')
def kinetic_energy(ds: xarray.Dataset) -> xarray.DataArray:
"""Calculate kinetic energy at each point in a velocity field."""
return (speed(ds) ** 2 / 2).rename('kinetic_energy')
def fourier_transform(array: xarray.DataArray) -> xarray.DataArray:
"""Calculate the fourier transform of an array, with labeled coordinates."""
# TODO(shoyer): consider switching to use xrft? https://github.com/xgcm/xrft
dims = [dim for dim in XR_SPATIAL_DIMS if dim in array.dims]
axes = [-1, -2, -3][:len(dims)]
result = xarray.apply_ufunc(
functools.partial(np.fft.fftn, axes=axes), array,
input_core_dims=[dims],
output_core_dims=[['k' + d for d in dims]],
output_sizes={'k' + d: array.sizes[d] for d in dims},
output_dtypes=[np.complex128],
dask='parallelized')
for d in dims:
step = float(array.coords[d][1] - array.coords[d][0])
freqs = 2 * np.pi * np.fft.fftfreq(array.sizes[d], step)
result.coords['k' + d] = freqs
# Ensure frequencies are in ascending order (equivalent to fftshift)
rolls = {'k' + d: array.sizes[d] // 2 for d in dims}
return result.roll(rolls, roll_coords=True)
def periodic_correlate(u, v):
"""Periodic correlation of arrays `u`, `v`, implemented using the FFT."""
return np.fft.ifft(np.fft.fft(u).conj() * np.fft.fft(v)).real
def spatial_autocorrelation(array, spatial_axis='x'):
"""Computes spatial autocorrelation of `array` along `spatial_axis`."""
spatial_axis_size = array.sizes[spatial_axis]
out_axis_name = 'd' + spatial_axis
full_result = xarray.apply_ufunc(
lambda x: periodic_correlate(x, x) / spatial_axis_size, array,
input_core_dims=[[spatial_axis]],
output_core_dims=[[out_axis_name]])
# we only report the unique half of the autocorrelation.
num_unique_displacements = spatial_axis_size // 2
result = full_result.isel({out_axis_name: slice(0, num_unique_displacements)})
displacement_coords = array.coords[spatial_axis][:num_unique_displacements]
result.coords[out_axis_name] = (out_axis_name, displacement_coords)
return result
@functools.partial(jax.jit, static_argnums=(0,), backend='cpu')
def _jax_numpy_add_at_zeros(shape, indices, values):
result = jnp.zeros(shape, dtype=values.dtype)
# equivalent to np.add.at(result, (..., indices), array), but much faster
return result.at[..., indices].add(values)
def _binned_sum_numpy(
array: np.ndarray,
indices: np.ndarray,
num_bins: int,
) -> np.ndarray:
"""NumPy helper function for summing over bins."""
mask = np.logical_not(np.isnan(indices))
int_indices = indices[mask].astype(int)
shape = array.shape[:-indices.ndim] + (num_bins,)
result = _jax_numpy_add_at_zeros(shape, int_indices, array[..., mask])
return np.asarray(result)
def groupby_bins_sum(
array: xarray.DataArray,
group: xarray.DataArray,
bins: np.ndarray,
labels: np.ndarray,
) -> xarray.DataArray:
"""Faster equivalent of Xarray's groupby_bins(...).sum()."""
# TODO(shoyer): remove this in favor of groupby_bin() once xarray's
# implementation is improved: https://github.com/pydata/xarray/issues/4473
bin_name = group.name + '_bins'
indices = group.copy(
data=pandas.cut(np.ravel(group), bins, labels=False).reshape(group.shape)
)
result = xarray.apply_ufunc(
_binned_sum_numpy, array, indices,
input_core_dims=[indices.dims, indices.dims],
output_core_dims=[[bin_name]],
output_dtypes=[array.dtype],
output_sizes={bin_name: labels.size},
kwargs={'num_bins': bins.size - 1},
dask='parallelized',
)
result[bin_name] = labels
return result
def _isotropize_binsum(ndim, energy):
"""Calculate energy spectrum summing over bins in wavenumber space."""
wavenumbers = [energy[name] for name in XR_WAVENUMBER_DIMS[:ndim]]
k_spacing = max(float(k[1] - k[0]) for k in wavenumbers)
k_max = min(float(w.max()) for w in wavenumbers) - 0.5 * k_spacing
k = magnitude(*wavenumbers).rename('k')
bounds = k_spacing * (np.arange(1, round(k_max / k_spacing) + 2) - 0.5)
labels = k_spacing * np.arange(1, round(k_max / k_spacing) + 1)
binned = groupby_bins_sum(energy, k, bounds, labels)
spectrum = binned.rename(k_bins='k')
return spectrum
def _isotropize_interpolation_2d(
energy, interpolation_method, num_quadrature_points,
):
"""Caclulate energy spectrum of a 2D signal with interpolation."""
# Calculate even spaced discrete levels for wavenumber magnitude
wavenumbers = [energy[name] for name in XR_WAVENUMBER_DIMS[:2]]
k_spacing = max(float(k[1] - k[0]) for k in wavenumbers)
k_max = min(float(w.max()) for w in wavenumbers) - 0.5 * k_spacing
k_values = k_spacing * np.arange(1, round(k_max / k_spacing) + 1)
k = xarray.DataArray(k_values, dims='k', coords={'k': k_values})
angle_values = np.linspace(
0, 2 * np.pi, num=num_quadrature_points, endpoint=False)
angle = xarray.DataArray(angle_values, dims='angle')
# Sample the spectrum at each point on the boundary of the circle with
# with radius k
kx = k * np.cos(angle)
ky = k * np.sin(angle)
# Interpolation on log(energy), which is much smoother in wavenumber space
# than the energy itself (which decays quite rapidly)
density = np.exp(
np.log(energy).interp(kx=kx, ky=ky, method=interpolation_method)
)
# Integrate over the edge of each circle
spectrum = 2 * np.pi * k * density.mean('angle')
return spectrum
def isotropize(
array: xarray.DataArray,
method: Optional[str] = None,
interpolation_method: str = 'linear',
num_quadrature_points: int = 100,
) -> xarray.DataArray:
"""Isotropize an ND spectrum by averaging over all angles.
Args:
array: array to isotropically average, with one or more dimensions
correspondings to wavenumbers.
method: either "interpolation" or "binsum".
interpolation_method: either "linear" or "nearest". Only used if using
method="interpolation".
num_quadrature_points: number of points to use when integrating over
wavenumbers with method="interpolation".
Returns:
Energy spectra as a function of wavenumber magnitude.
"""
ndim = sum(dim in array.dims for dim in XR_WAVENUMBER_DIMS)
if ndim == 0:
raise ValueError(f'no frequency dimensions found: {array.dims}')
if method is None:
method = 'interpolation' if ndim == 2 else 'binsum'
if method == 'interpolation':
if ndim != 2:
raise ValueError('interpolation not yet supported for non-2D inputs')
# TODO(shoyer): switch to more accurate algorithms for both 1D and 3D, too:
# - 1D can simply add up the energy at positive and negative frequencies
# - 3D can efficiently integrate over all angles using Lebedev quadrature:
# https://en.wikipedia.org/wiki/Lebedev_quadrature
return _isotropize_interpolation_2d(
array, interpolation_method, num_quadrature_points)
elif method == 'binsum':
# NOTE(shoyer): I believe this function is equivalent to
# xrft.isotropize(), but is faster & more efficient because we
# use groupby_bins_sum(). See https://github.com/xgcm/xrft/issues/9
return _isotropize_binsum(ndim, array)
else:
raise ValueError(f'invalid method: {method}')
def energy_spectrum(ds: xarray.Dataset) -> xarray.DataArray:
"""Calculate the kinetic energy spectra at each wavenumber.
Args:
ds: dataset with `u`, `v` and/or `w` velocity components and corresponding
spatial dimensions.
Returns:
Energy spectra as a function of wavenumber instead of space.
"""
ndim = sum(dim in ds.dims for dim in 'xyz')
velocity_components = list(XR_VELOCITY_NAMES[:ndim])
fourier_ds = ds[velocity_components].map(fourier_transform)
return kinetic_energy(fourier_ds)
def isotropic_energy_spectrum(
ds: xarray.Dataset,
average_dims: Tuple[str, ...] = (),
) -> xarray.DataArray:
"""Calculate the energy spectra at each scalar wavenumber.
Args:
ds: dataset with `u`, `v` and/or `w` velocity components and corresponding
spatial dimensions.
average_dims: dimensions to average over before isotropic averaging.
Returns:
Energy spectra as a function of wavenumber magnitude, without spatial
dimensions.
"""
return isotropize(energy_spectrum(ds).mean(average_dims))
def velocity_spatial_correlation(
ds: xarray.Dataset,
axis: str
) ->xarray.Dataset:
"""Computes velocity correlation along `axis` for all velocity components."""
ndim = sum(dim in ds.dims for dim in 'xyz')
velocity_components = list(XR_VELOCITY_NAMES[:ndim])
correlation_fn = lambda x: spatial_autocorrelation(x, axis)
correlations = ds[velocity_components].map(correlation_fn)
name_mapping = {u: '_'.join([u, axis, 'correlation'])
for u in velocity_components}
return correlations.rename(name_mapping)
def normalize(array: xarray.DataArray, state_dims: Tuple[str, ...]):
"""Returns `array` with slices along `state_dims` normalized to unity."""
norm = np.sqrt((array ** 2).sum(state_dims))
return array / norm
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.data.xarray_utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax_cfd.base import test_util
from jax_cfd.data import xarray_utils
import numpy as np
import xarray
class XarrayUtilsTest(test_util.TestCase):
"""Tests utility functions interacting with xarray."""
@parameterized.parameters(
dict(all_dims=('time', 'x', 'y', 'sample'), state_dims=('x', 'y'),),
dict(all_dims=('x', 'y', 'z', 'sample'), state_dims=('x', 'z', 'y'),),
dict(all_dims=('time', 'x'), state_dims=('x'),),
dict(all_dims=('x', 'sample', 'y'), state_dims=('x', 'y'),),
dict(all_dims=('x', 'z', 'y'), state_dims=('x', 'y', 'z'),),
)
def test_normalize(self, all_dims, state_dims):
"""Tests that `normalize` returns data with expected shapes and norms."""
self.skipTest("test is sensitive to its random seed")
shape_key, value_key = jax.random.split(jax.random.PRNGKey(42), 2)
input_shape = jax.random.randint(shape_key, (len(all_dims),), 1, 4)
inputs = jax.random.normal(value_key, input_shape)
non_state_dims = [dim for dim in all_dims if dim not in state_dims]
get_dim_axis_fn = lambda dim: np.where(np.asarray(all_dims) == dim)[0][0]
state_axes = [get_dim_axis_fn(dim) for dim in state_dims]
# to compute expected values we move state dimensions to the first axes,
# divide by the norm and then reshape back.
inputs_ordered = np.moveaxis(inputs, state_axes, np.arange(len(state_axes)))
vec_shape = (-1,) + inputs_ordered.shape[-len(non_state_dims):]
inputs_vec = np.reshape(inputs_ordered, vec_shape)
inputs_vec = inputs_vec / np.linalg.norm(inputs_vec, axis=0)
normalized = np.reshape(inputs_vec, inputs_ordered.shape)
expected = np.moveaxis(normalized, np.arange(len(state_axes)), state_axes)
coords = {dim: np.arange(input_shape[i]) for i, dim in enumerate(all_dims)}
array = xarray.DataArray(inputs, coords, all_dims)
normalized_array = xarray_utils.normalize(array, state_dims)
actual = normalized_array.transpose(*all_dims).values
self.assertAllClose(expected, actual, atol=1e-6)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An ML modeling library built on Haiku and Gin-Config for JAX-CFD."""
import jax_cfd.ml.advections
import jax_cfd.ml.decoders
import jax_cfd.ml.diffusions
import jax_cfd.ml.encoders
import jax_cfd.ml.equations
import jax_cfd.ml.forcings
import jax_cfd.ml.interpolations
import jax_cfd.ml.layers
import jax_cfd.ml.layers_util
import jax_cfd.ml.model_builder
import jax_cfd.ml.model_utils
import jax_cfd.ml.networks
import jax_cfd.ml.nonlinearities
import jax_cfd.ml.optimizer_modules
import jax_cfd.ml.physics_specifications
import jax_cfd.ml.pressures
import jax_cfd.ml.tiling
import jax_cfd.ml.time_integrators
import jax_cfd.ml.towers
import jax_cfd.ml.viscosities
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment