Commit 5e31fa1f authored by mashun1's avatar mashun1
Browse files

jax-cfd

parents
Pipeline #1015 canceled with stages
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for functionality related to advection."""
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import interpolation
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = interpolation.InterpolationFn
# TODO(dkochkov) Consider testing if we need operator splitting methods.
def _advect_aligned(cs: GridVariableVector, v: GridVariableVector) -> GridArray:
"""Computes fluxes and the associated advection for aligned `cs` and `v`.
The values `cs` should consist of a single quantity `c` that has been
interpolated to the offset of the components of `v`. The components of `v` and
`cs` should be located at the faces of a single (possibly offset) grid cell.
We compute the advection as the divergence of the flux on this control volume.
The boundary condition on the flux is inherited from the scalar quantity `c`.
A typical example in three dimensions would have
```
cs[0].offset == v[0].offset == (1., .5, .5)
cs[1].offset == v[1].offset == (.5, 1., .5)
cs[2].offset == v[2].offset == (.5, .5, 1.)
```
In this case, the returned advection term would have offset `(.5, .5, .5)`.
Args:
cs: a sequence of `GridArray`s; a single value `c` that has been
interpolated so that it is aligned with each component of `v`.
v: a sequence of `GridArrays` describing a velocity field. Should be defined
on the same Grid as cs.
Returns:
An `GridArray` containing the time derivative of `c` due to advection by
`v`.
Raises:
ValueError: `cs` and `v` have different numbers of components.
AlignmentError: if the components of `cs` are not aligned with those of `v`.
"""
# TODO(jamieas): add more sophisticated alignment checks, ensuring that the
# values are located on the faces of a control volume.
if len(cs) != len(v):
raise ValueError('`cs` and `v` must have the same length;'
f'got {len(cs)} vs. {len(v)}.')
flux = tuple(c.array * u.array for c, u in zip(cs, v))
bcs = tuple(
boundaries.get_advection_flux_bc_from_velocity_and_scalar(v[i], cs[i], i)
for i in range(len(v)))
flux = tuple(bc.impose_bc(f) for f, bc in zip(flux, bcs))
return -fd.divergence(flux)
def advect_general(
c: GridVariable,
v: GridVariableVector,
u_interpolation_fn: InterpolationFn,
c_interpolation_fn: InterpolationFn,
dt: Optional[float] = None) -> GridArray:
"""Computes advection of a scalar quantity `c` by the velocity field `v`.
This function follows the following procedure:
1. Interpolate each component of `v` to the corresponding face of the
control volume centered on `c`.
2. Interpolate `c` to the same control volume faces.
3. Compute the flux `cu` using the aligned values.
4. Set the boundary condition on flux, which is inhereited from `c`.
5. Return the negative divergence of the flux.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
u_interpolation_fn: method for interpolating velocity field `v`.
c_interpolation_fn: method for interpolating scalar field `c`.
dt: unused time-step.
Returns:
The time derivative of `c` due to advection by `v`.
"""
if not boundaries.has_all_periodic_boundary_conditions(c):
raise NotImplementedError(
'Non-periodic boundary conditions are not implemented.')
target_offsets = grids.control_volume_offsets(c)
aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt)
for u, target_offset in zip(v, target_offsets))
aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
for target_offset in target_offsets)
return _advect_aligned(aligned_c, aligned_v)
def advect_linear(c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None) -> GridArray:
"""Computes advection using linear interpolations."""
return advect_general(c, v, interpolation.linear, interpolation.linear, dt)
def advect_upwind(c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None) -> GridArray:
"""Computes advection using first-order upwind interpolation on `c`."""
return advect_general(c, v, interpolation.linear, interpolation.upwind, dt)
def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]:
"""Returns interpolated components of `v` needed for convection.
Args:
v: a velocity field.
Returns:
A d-tuple of d-tuples of `GridVariable`s `aligned_v`, where `d = len(v)`.
The entry `aligned_v[i][j]` is the component `v[i]` after interpolation to
the appropriate face of the control volume centered around `v[j]`.
"""
grid = grids.consistent_grid(*v)
offsets = tuple(grids.control_volume_offsets(u) for u in v)
aligned_v = tuple(
tuple(interpolation.linear(v[i], offsets[i][j])
for j in range(grid.ndim))
for i in range(grid.ndim))
return aligned_v
def _velocities_to_flux(
aligned_v: Tuple[GridVariableVector]) -> Tuple[GridVariableVector]:
"""Computes the fluxes across the control volume faces for a velocity field.
This is the flux associated with the nonlinear term `vv` for velocity `v`.
The boundary condition on the flux is inherited from `v`.
Args:
aligned_v: a d-tuple of d-tuples of `GridVariable`s such that the entry
`aligned_v[i][j]` is the component `v[i]` after interpolation to
the appropriate face of the control volume centered around `v[j]`. This is
the output of `_align_velocities`.
Returns:
A tuple of tuples `flux` of `GridVariable`s with the same structure as
`aligned_v`. The entry `flux[i][j]` is `aligned_v[i][j] * aligned_v[j][i]`.
"""
ndim = len(aligned_v)
flux = [tuple() for _ in range(ndim)]
for i in range(ndim):
for j in range(ndim):
if i <= j:
bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar(
aligned_v[j][i], aligned_v[i][j], j)
flux[i] += (bc.impose_bc(aligned_v[i][j].array *
aligned_v[j][i].array),)
else:
flux[i] += (flux[j][i],)
return tuple(flux)
def convect_linear(v: GridVariableVector) -> GridArrayVector:
"""Computes convection/self-advection of the velocity field `v`.
This function is conceptually equivalent to
```
def convect_linear(v, grid):
return tuple(advect_linear(u, v, grid) for u in v)
```
However, there are several optimizations to avoid duplicate operations.
Args:
v: a velocity field.
Returns:
A tuple containing the time derivative of each component of `v` due to
convection.
"""
# TODO(jamieas): consider a more efficient vectorization of this function.
# TODO(jamieas): incorporate variable density.
aligned_v = _align_velocities(v)
fluxes = _velocities_to_flux(aligned_v)
return tuple(-fd.divergence(flux) for flux in fluxes)
def advect_van_leer(
c: GridVariable,
v: GridVariableVector,
dt: float,
mode: str = boundaries.Padding.MIRROR,
) -> GridArray:
"""Computes advection of a scalar quantity `c` by the velocity field `v`.
Implements Van-Leer flux limiting scheme that uses second order accurate
approximation of fluxes for smooth regions of the solution. This scheme is
total variation diminishing (TVD). For regions with high gradients flux
limitor transformes the scheme into a first order method. For [1] for
reference. This function follows the following procedure:
1. Shifts c to offset < 1 if necessary.
2. Scalar c now has a well defined right-hand (upwind) value.
3. Computes upwind flux for each direction.
4. Computes van leer flux limiter:
a. Use the shifted c to interpolate each component of `v` to the
right-hand (upwind) face of the control volume centered on `c`.
b. Compute the ratio of successive gradients:
In nonperiodic case, the value outside the boundary is not defined.
Mode is used to interpolate past the boundary.
c. Compute flux limiter function.
d. Computes higher order flux correction.
5. Combines fluxes and assigns flux boundary condition.
6. Computes the negative divergence of fluxes.
7. Shifts the computed values back to original offset of c.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
dt: time step for which this scheme is TVD and second order accurate
in time.
mode: For non-periodic BC, specifies extrapolation of values beyond the
boundary, which is used by nonlinear interpolation.
Returns:
The time derivative of `c` due to advection by `v`.
#### References
[1]: MIT 18.336 spring 2009 Finite Volume Methods Lecture 19.
go/mit-18.336-finite_volume_methods-19
[2]:
www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf
"""
# TODO(dkochkov) reimplement this using apply_limiter method.
c_left_var = c
# if the offset is 1., shift by 1 to offset 0.
# otherwise c_right is not defined.
for ax in range(c.grid.ndim):
# int(c.offset[ax] % 1 - c.offset[ax]) = -1 if c.offset[ax] is 1 else
# int(c.offset[ax] % 1 - c.offset[ax]) = 0.
# i.e. this shifts the 1 aligned data to 0 offset, the rest is unchanged.
c_left_var = c.bc.impose_bc(
c_left_var.shift(int(c.offset[ax] % 1 - c.offset[ax]), axis=ax))
offsets = grids.control_volume_offsets(c_left_var)
# if c offset is 0, aligned_v is at 0.5.
# if c offset is at .5, aligned_v is at 1.
aligned_v = tuple(interpolation.linear(u, offset)
for u, offset in zip(v, offsets))
flux = []
# Assign flux boundary condition
flux_bc = [
boundaries.get_advection_flux_bc_from_velocity_and_scalar(u, c, direction)
for direction, u in enumerate(v)
]
# first, compute upwind flux.
for axis, u in enumerate(aligned_v):
c_center = c_left_var.data
# by shifting c_left + 1, c_right is well-defined.
c_right = c_left_var.shift(+1, axis=axis).data
upwind_flux = grids.applied(jnp.where)(
u.array > 0, u.array * c_center, u.array * c_right)
flux.append(upwind_flux)
# next, compute van_leer correction.
for axis, (u, h) in enumerate(zip(aligned_v, c.grid.step)):
u = u.bc.shift(u.array, int(u.offset[axis] % 1 - u.offset[axis]), axis=axis)
# c is put to offset .5 or 1.
c_center_arr = c.shift(int(1 - c.offset[ax]), axis=ax)
# if c offset is 1, u offset is .5.
# if c offset is .5, u offset is 0.
# u_i is always on the left of c_center_var_i
c_center = c_center_arr.data
# shift -1 are well defined now
# shift +1 is not well defined for c offset 1 because then c(wall + 1) is
# not defined.
# However, the flux that uses c(wall + 1) offset gets overridden anyways
# when flux boundary condition is overridden.
# Thus, any mode can be used here.
c_right = c.bc.shift(c_center_arr, +1, axis=axis, mode=mode).data
c_left = c.bc.shift(c_center_arr, -1, axis=axis).data
# shift -2 is tricky:
# It is well defined if c is periodic.
# Else, c(-1) or c(-1.5) are not defined.
# Then, mode is used to interpolate the values.
c_left_left = c.bc.shift(
c_center_arr, -2, axis, mode=mode).data
numerator_positive = c_left - c_left_left
numerator_negative = c_right - c_center
numerator = grids.applied(jnp.where)(u > 0, numerator_positive,
numerator_negative)
denominator = grids.GridArray(c_center - c_left, u.offset, u.grid)
# We want to calculate denominator / (abs(denominator) + abs(numerator))
# To make it differentiable, it needs to be done in stages.
# ensures that there is no division by 0
phi_van_leer_denominator_avoid_nans = grids.applied(jnp.where)(
abs(denominator) > 0, (abs(denominator) + abs(numerator)), 1.)
phi_van_leer_denominator_inv = denominator / phi_van_leer_denominator_avoid_nans
phi_van_leer = numerator * (grids.applied(jnp.sign)(denominator) +
grids.applied(jnp.sign)
(numerator)) * phi_van_leer_denominator_inv
abs_velocity = abs(u)
courant_numbers = (dt / h) * abs_velocity
pre_factor = 0.5 * (1 - courant_numbers) * abs_velocity
flux_correction = pre_factor * phi_van_leer
# Shift back onto original offset.
flux_correction = flux_bc[axis].shift(
flux_correction, int(offsets[axis][axis] - u.offset[axis]), axis=axis)
flux[axis] += flux_correction
flux = tuple(flux_bc[axis].impose_bc(f) for axis, f in enumerate(flux))
advection = -fd.divergence(flux)
# shift the variable back onto the original offset
for ax in range(c.grid.ndim):
advection = c.bc.shift(
advection, -int(c.offset[ax] % 1 - c.offset[ax]), axis=ax)
return advection
def advect_step_semilagrangian(
c: GridVariable,
v: GridVariableVector,
dt: float
) -> GridVariable:
"""Semi-Lagrangian advection of a scalar quantity.
Note that unlike the other advection functions, this function returns values
at the next time-step, not the time derivative.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
dt: desired time-step.
Returns:
Advected quantity at the next time-step -- *not* the time derivative.
"""
# Reference: "Learning to control PDEs with Differentiable Physics"
# https://openreview.net/pdf?id=HyeSin4FPB (see Appendix A)
grid = grids.consistent_grid(c, *v)
# TODO(shoyer) Enable lower domains != 0 for this function.
# Hint: indices = [
# -o + (x - l) * n / (u - l)
# for (l, u), o, x, n in zip(grid.domain, c.offset, coords, grid.shape)
# ]
if not all(d[0] == 0 for d in grid.domain):
raise ValueError(
f'Grid domains currently must start at zero. Found {grid.domain}')
coords = [x - dt * interpolation.linear(u, c.offset).data
for x, u in zip(grid.mesh(c.offset), v)]
indices = [x / s - o for s, o, x in zip(grid.step, c.offset, coords)]
if not boundaries.has_all_periodic_boundary_conditions(c):
raise NotImplementedError('non-periodic BCs not yet supported')
c_advected = grids.applied(jax.scipy.ndimage.map_coordinates)(
c.array, indices, order=1, mode='wrap')
return GridVariable(c_advected, c.bc)
# TODO(dkochkov) Implement advect_with_flux_limiter method.
# TODO(dkochkov) Consider moving `advect_van_leer` to test based on performance.
def advect_van_leer_using_limiters(
c: GridVariable,
v: GridVariableVector,
dt: float
) -> GridArray:
"""Implements Van-Leer advection by applying TVD limiter to Lax-Wendroff."""
c_interpolation_fn = interpolation.apply_tvd_limiter(
interpolation.lax_wendroff, limiter=interpolation.van_leer_limiter)
return advect_general(c, v, interpolation.linear, c_interpolation_fn, dt)
def stable_time_step(max_velocity: float,
max_courant_number: float,
grid: grids.Grid) -> float:
"""Calculate a stable time step size for explicit advection.
The calculation is based on the CFL condition for advection.
Args:
max_velocity: maximum velocity.
max_courant_number: the Courant number used to choose the time step. Smaller
numbers will lead to more stable simulations. Typically this should be in
the range [0.5, 1).
grid: a `Grid` object.
Returns:
The prescribed time interval.
"""
dx = min(grid.step)
return max_courant_number * dx / max_velocity
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.advection."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import boundaries
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
def _gaussian_concentration(grid):
offset = tuple(-int(jnp.ceil(s / 2.)) for s in grid.shape)
return grids.GridArray(
jnp.exp(-sum(jnp.square(m) * 30. for m in grid.mesh(offset=offset))),
(0.5,) * len(grid.shape), grid)
def _square_concentration(grid):
select_square = lambda x: jnp.where(jnp.logical_and(x > 0.4, x < 0.6), 1., 0.)
return grids.GridArray(
jnp.array([select_square(m) for m in grid.mesh()]).prod(0),
(0.5,) * len(grid.shape), grid)
def _unit_velocity(grid, velocity_sign=1.):
ndim = grid.ndim
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(velocity_sign * jnp.ones(grid.shape) if ax == 0
else jnp.zeros(grid.shape), tuple(offset), grid)
for ax, offset in enumerate(offsets))
def _cos_velocity(grid):
ndim = grid.ndim
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
mesh = grid.mesh()
v = tuple(grids.GridArray(jnp.cos(mesh[i] * 2. * np.pi), tuple(offset), grid)
for i, offset in enumerate(offsets))
return v
def _velocity_implicit(grid, offset, u, t):
"""Returns solution of a Burgers equation at time `t`."""
x = grid.mesh((offset,))[0]
return grids.GridArray(jnp.sin(x - u * t), (offset,), grid)
def _total_variation(array, motion_axis):
next_values = array.shift(1, motion_axis)
variation = jnp.sum(jnp.abs(next_values.data - array.data))
return variation
def _euler_step(advection_method):
def step(c, v, dt):
c_new = c.array + dt * advection_method(c, v, dt)
return c.bc.impose_bc(c_new)
return step
class AdvectionTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='linear_1D',
shape=(101,),
method=_euler_step(advection.advect_linear),
num_steps=1000,
cfl_number=0.01,
atol=5e-2),
dict(testcase_name='linear_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_linear),
num_steps=1000,
cfl_number=0.01,
atol=5e-2),
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind),
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='upwind_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_upwind),
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_using_limiters_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer_using_limiters),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_using_limiters_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_van_leer_using_limiters),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian,
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='semilagrangian_3D',
shape=(101, 5, 5),
method=advection.advect_step_semilagrangian,
num_steps=100,
cfl_number=0.5,
atol=7e-2),
)
def test_advection_analytical(
self, shape, method, num_steps, cfl_number, atol, v_sign=1):
"""Tests advection of a Gaussian concentration on a periodic grid."""
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = cfl_number * min(step)
advect = functools.partial(method, v=v, dt=dt)
evolve = jax.jit(funcutils.repeated(advect, num_steps))
ct = evolve(c)
expected_shift = int(round(-cfl_number * num_steps * v_sign))
expected = c.shift(expected_shift, axis=0).data
self.assertAllClose(expected, ct.data, atol=atol)
@parameterized.named_parameters(
dict(
testcase_name='dirichlet_1d_100', shape=(100,), atol=0.001,
offset=.5),
dict(
testcase_name='dirichlet_1d_200',
shape=(200,),
atol=0.00025,
offset=.5),
dict(
testcase_name='dirichlet_1d_400',
shape=(400,),
atol=0.00007,
offset=.5),
dict(
testcase_name='dirichlet_1d_100_cell_edge_0',
shape=(100,),
atol=0.002,
offset=0.),
dict(
testcase_name='dirichlet_1d_200_cell_edge_0',
shape=(200,),
atol=0.0005,
offset=0.),
dict(
testcase_name='dirichlet_1d_400_cell_edge_0',
shape=(400,),
atol=0.000125,
offset=0.),
dict(
testcase_name='dirichlet_1d_100_cell_edge_1',
shape=(100,),
atol=0.002,
offset=1.),
dict(
testcase_name='dirichlet_1d_200_cell_edge_1',
shape=(200,),
atol=0.0005,
offset=1.),
dict(
testcase_name='dirichlet_1d_400_cell_edge_1',
shape=(400,),
atol=0.000125,
offset=1.),
)
def test_burgers_analytical_dirichlet_convergence(
self,
shape,
atol,
offset,
):
num_steps = 1000
cfl_number = 0.01
step = 2 * jnp.pi / 1000
grid = grids.Grid(shape, domain=([0., 2 * jnp.pi],))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
v = (bc.impose_bc(_velocity_implicit(grid, offset, 0, 0)),)
dt = cfl_number * step
def _advect(v):
dv_dt = advection.advect_van_leer(c=v[0], v=v, dt=dt) / 2
return (bc.impose_bc(v[0].array + dt * dv_dt),)
evolve = jax.jit(funcutils.repeated(_advect, num_steps))
ct = evolve(v)
expected = bc.impose_bc(
_velocity_implicit(grid, offset, ct[0].data, dt * num_steps)).data
self.assertAllClose(expected, ct[0].data, atol=atol)
@parameterized.named_parameters(
dict(testcase_name='linear_1D',
shape=(101,),
method=_euler_step(advection.advect_linear),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='linear_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_linear),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='upwind_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_upwind),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_using_limiters_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer_using_limiters),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_using_limiters_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer_using_limiters),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian,
atol=1e-2,
rtol=1e-2),
dict(testcase_name='semilagrangian_3D',
shape=(101, 101, 101),
method=advection.advect_step_semilagrangian,
atol=1e-2,
rtol=1e-2),
)
def test_advection_gradients(
self, shape, method, atol, rtol, cfl_number=0.5, v_sign=1,
):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = cfl_number * min(step)
advect = jax.remat(functools.partial(method, v=v, dt=dt))
evolve = jax.jit(funcutils.repeated(advect, steps=10))
def objective(c):
return 0.5 * jnp.sum(evolve(c).data ** 2)
gradient = jax.jit(jax.grad(objective))(c)
self.assertAllClose(c, gradient, atol=atol, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
)
def test_advection_gradients_division_by_zero(
self, shape, method, atol, rtol, cfl_number=0.5, v_sign=1,
):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_unit_velocity(grid)[0], bc)
dt = cfl_number * min(step)
advect = jax.remat(functools.partial(method, v=v, dt=dt))
evolve = jax.jit(funcutils.repeated(advect, steps=10))
def objective(c):
return 0.5 * jnp.sum(evolve(c).data ** 2)
gradient = jax.jit(jax.grad(objective))(c)
self.assertAllClose(c, gradient, atol=atol, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name='_linear_1D',
shape=(101,),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear),
dict(testcase_name='_linear_3D',
shape=(101, 101, 101),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear)
)
def test_convection_vs_advection(
self, shape, advection_method, convection_method,
):
"""Exercises self-advection, check equality with advection on components."""
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _cos_velocity(grid))
self_advected = convection_method(v)
for u, du in zip(v, self_advected):
advected_component = advection_method(u, v)
self.assertAllClose(advected_component, du)
@parameterized.named_parameters(
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind)),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer)),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian),
)
def test_tvd_property(self, shape, method):
atol = 1e-6
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid))
c = grids.GridVariable(_square_concentration(grid), bc)
dt = min(step) / 100.
num_steps = 300
ct = c
advect = jax.jit(functools.partial(method, v=v, dt=dt))
initial_total_variation = _total_variation(c, 0) + atol
for _ in range(num_steps):
ct = advect(ct)
current_total_variation = _total_variation(ct, 0)
self.assertLessEqual(current_total_variation, initial_total_variation)
@parameterized.named_parameters(
dict(
testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer)),
)
def test_mass_conservation(self, shape, method):
offset = 0.5
cfl_number = 0.1
dt = cfl_number / shape[0]
num_steps = 1000
grid = grids.Grid(shape, domain=([-1., 1.],))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),))
def u(grid, offset):
x = grid.mesh((offset,))[0]
return grids.GridArray(-jnp.sin(jnp.pi * x), (offset,), grid)
def c0(grid, offset):
x = grid.mesh((offset,))[0]
return grids.GridArray(x, (offset,), grid)
v = (bc.impose_bc(u(grid, 1.)),)
c = c_bc.impose_bc(c0(grid, offset))
ct = c
advect = jax.jit(functools.partial(method, v=v, dt=dt))
initial_mass = np.sum(c.data)
for _ in range(num_steps):
ct = advect(ct)
current_total_mass = np.sum(ct.data)
self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)
@parameterized.named_parameters(
dict(testcase_name='van_leers_equivalence_1d',
shape=(101,), v_sign=1.),
dict(testcase_name='van_leers_equivalence_3d',
shape=(101, 101, 101), v_sign=1.),
dict(testcase_name='van_leers_equivalence_1d_negative_v',
shape=(101,), v_sign=-1.),
dict(testcase_name='van_leers_equivalence_3d_negative_v',
shape=(101, 101, 101), v_sign=-1.),
)
def test_van_leer_same_as_van_leer_using_limiters(self, shape, v_sign):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = min(step) / 100.
num_steps = 100
advect_vl = jax.jit(
functools.partial(_euler_step(advection.advect_van_leer), v=v, dt=dt))
advect_vl_using_limiter = jax.jit(
functools.partial(
_euler_step(advection.advect_van_leer_using_limiters), v=v, dt=dt))
c_vl = c
c_vl_using_limiter = c
for _ in range(num_steps):
c_vl = advect_vl(c_vl)
c_vl_using_limiter = advect_vl_using_limiter(c_vl_using_limiter)
self.assertAllClose(c_vl, c_vl_using_limiter, atol=1e-5)
if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods for manipulating array-like objects."""
from typing import Any, Callable, List, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np
import scipy.linalg
# There is currently no good way to indicate a jax "pytree" with arrays at its
# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any
Array = Union[np.ndarray, jax.Array]
def _normalize_axis(axis: int, ndim: int) -> int:
"""Validates and returns positive `axis` value."""
if not -ndim <= axis < ndim:
raise ValueError(f'invalid axis {axis} for ndim {ndim}')
if axis < 0:
axis += ndim
return axis
def slice_along_axis(
inputs: PyTree,
axis: int,
idx: Union[slice, int],
expect_same_dims: bool = True
) -> PyTree:
"""Returns slice of `inputs` defined by `idx` along axis `axis`.
Args:
inputs: array or a tuple of arrays to slice.
axis: axis along which to slice the `inputs`.
idx: index or slice along axis `axis` that is returned.
expect_same_dims: whether all arrays should have same number of dimensions.
Returns:
Slice of `inputs` defined by `idx` along axis `axis`.
"""
# arrays, tree_def = jax.tree_util.flatten(inputs)
arrays, tree_def = jax.tree_util.tree_flatten(inputs)
ndims = set(a.ndim for a in arrays)
if expect_same_dims and len(ndims) != 1:
raise ValueError('arrays in `inputs` expected to have same ndims, but have '
f'{ndims}. To allow this, pass expect_same_dims=False')
sliced = []
for array in arrays:
ndim = array.ndim
slc = tuple(idx if j == _normalize_axis(axis, ndim) else slice(None)
for j in range(ndim))
sliced.append(array[slc])
return jax.tree_util.tree_unflatten(tree_def, sliced)
def split_along_axis(
inputs: PyTree,
split_idx: int,
axis: int,
expect_same_dims: bool = True
) -> Tuple[PyTree, PyTree]:
"""Returns a tuple of slices of `inputs` split along `axis` at `split_idx`.
Args:
inputs: pytree of arrays to split.
split_idx: index along `axis` where the second split starts.
axis: axis along which to split the `inputs`.
expect_same_dims: whether all arrays should have same number of dimensions.
Returns:
Tuple of slices of `inputs` split along `axis` at `split_idx`.
"""
first_slice = slice_along_axis(
inputs, axis, slice(0, split_idx), expect_same_dims)
second_slice = slice_along_axis(
inputs, axis, slice(split_idx, None), expect_same_dims)
return first_slice, second_slice
def split_axis(
inputs: PyTree,
axis: int,
keep_dims: bool = False
) -> Tuple[PyTree, ...]:
"""Splits the arrays in `inputs` along `axis`.
Args:
inputs: pytree to be split.
axis: axis along which to split the `inputs`.
keep_dims: whether to keep `axis` dimension.
Returns:
Tuple of pytrees that correspond to slices of `inputs` along `axis`. The
`axis` dimension is removed if `squeeze is set to True.
Raises:
ValueError: if arrays in `inputs` don't have unique size along `axis`.
"""
arrays, tree_def = jax.tree_util.flatten(inputs)
axis_shapes = set(a.shape[axis] for a in arrays)
if len(axis_shapes) != 1:
raise ValueError(f'Arrays must have equal sized axis but got {axis_shapes}')
axis_shape, = axis_shapes
splits = [jnp.split(a, axis_shape, axis=axis) for a in arrays]
if not keep_dims:
splits = jax.tree_util.tree_map(lambda a: jnp.squeeze(a, axis), splits)
splits = zip(*splits)
return tuple(jax.tree_util.unflatten(tree_def, leaves) for leaves in splits)
def concat_along_axis(pytrees, axis):
"""Concatenates `pytrees` along `axis`."""
concat_leaves_fn = lambda *args: jnp.concatenate(args, axis)
return jax.tree_util.tree_map(concat_leaves_fn, *pytrees)
def block_reduce(
array: Array,
block_size: Tuple[int, ...],
reduction_fn: Callable[[Array], Array]
) -> Array:
"""Breaks `array` into `block_size` pieces and applies `f` to each.
This function is equivalent to `scikit-image.measure.block_reduce`:
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.block_reduce
Args:
array: an array.
block_size: the size of the blocks on which the reduction is performed.
Must evenly divide `array.shape`.
reduction_fn: a reduction function that will be applied to each block of
size `block_size`.
Returns:
The result of applying `f` to each block of size `block_size`.
"""
new_shape = []
for b, s in zip(block_size, array.shape):
multiple, residual = divmod(s, b)
if residual != 0:
raise ValueError('`block_size` must divide `array.shape`;'
f'got {block_size}, {array.shape}.')
new_shape += [multiple, b]
multiple_axis_reduction_fn = reduction_fn
for j in reversed(range(array.ndim)):
multiple_axis_reduction_fn = jax.vmap(multiple_axis_reduction_fn, j)
return multiple_axis_reduction_fn(array.reshape(new_shape))
def laplacian_matrix(size: int, step: float) -> np.ndarray:
"""Create 1D Laplacian operator matrix, with periodic BC."""
column = np.zeros(size)
column[0] = -2 / step**2
column[1] = column[-1] = 1 / step**2
return scipy.linalg.circulant(column)
def _laplacian_boundary_dirichlet_cell_centered(laplacians: List[Array],
grid: grids.Grid, axis: int,
side: str) -> None:
"""Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.
laplacians[i] contains a 3 point stencil matrix L that approximates
d^2/dx_i^2.
For detailed documentation on laplacians input type see
array_utils.laplacian_matrix.
The default return of array_utils.laplacian_matrix makes a matrix for
periodic boundary. For dirichlet boundary, the correct equation is
L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
laplacian_boundary_dirichlet restricts the matrix L to
interior points only.
This function assumes RHS has cell-centered offset.
Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose dirichlet bc.
side: lower or upper side to assign boundary to.
Returns:
updated list of 1d laplacians.
"""
# This function assumes homogeneous boundary, in which case if the offset
# is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
# 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
if side == 'lower':
laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis]**2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis]**2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return
def _laplacian_boundary_neumann_cell_centered(laplacians: List[Array],
grid: grids.Grid, axis: int,
side: str) -> None:
"""Converts 1d laplacian matrix to satisfy neumann homogeneous bc.
This function assumes the RHS will have a cell-centered offset.
Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
code.
Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose dirichlet bc.
side: which boundary side to convert to neumann homogeneous bc.
Returns:
updated list of 1d laplacians.
"""
if side == 'lower':
laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis]**2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis]**2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return
def laplacian_matrix_w_boundaries(
grid: grids.Grid,
offset: Tuple[float, ...],
bc: boundaries.BoundaryConditions,
) -> List[Array]:
"""Returns 1d laplacians that satisfy boundary conditions bc on grid.
Given grid, offset and boundary conditions, returns a list of 1 laplacians
(one along each axis).
Currently, only homogeneous or periodic boundary conditions are supported.
Args:
grid: The grid used to construct the laplacian.
offset: The offset of the variable on which laplacian acts.
bc: the boundary condition of the variable on which the laplacian acts.
Returns:
A list of 1d laplacians.
"""
if not isinstance(bc, boundaries.ConstantBoundaryConditions):
raise NotImplementedError(
f'Explicit laplacians are not implemented for {bc}.')
laplacians = list(map(laplacian_matrix, grid.shape, grid.step))
for axis in range(grid.ndim):
if np.isclose(offset[axis], 0.5):
for i, side in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.NEUMANN:
_laplacian_boundary_neumann_cell_centered(
laplacians, grid, axis, side)
elif bc.types[axis][i] == boundaries.BCType.DIRICHLET:
_laplacian_boundary_dirichlet_cell_centered(
laplacians, grid, axis, side)
if np.isclose(offset[axis] % 1, 0.):
if bc.types[axis][0] == boundaries.BCType.DIRICHLET and bc.types[
axis][1] == boundaries.BCType.DIRICHLET:
# This function assumes homogeneous boundary and acts on the interior.
# Thus, the laplacian can be cut off past the edge.
# The interior grid has one fewer grid cell than the actual grid, so
# the size of the laplacian should be reduced.
laplacians[axis] = laplacians[axis][:-1, :-1]
elif boundaries.BCType.NEUMANN in bc.types[axis]:
raise NotImplementedError(
'edge-aligned Neumann boundaries are not implemented.')
return laplacians
def unstack(array, axis):
"""Returns a tuple of slices of `array` along axis `axis`."""
squeeze_fn = lambda x: jnp.squeeze(x, axis=axis)
return tuple(squeeze_fn(x) for x in jnp.split(array, array.shape[axis], axis))
def gram_schmidt_qr(
matrix: Array,
precision: jax.lax.Precision = jax.lax.Precision.HIGHEST
) -> Tuple[Array, Array]:
"""Computes QR decomposition using gramm-schmidt orthogonalization.
This algorithm is suitable for tall matrices with very few columns. This
method is more memory efficient compared to `jnp.linalg.qr`, but is less
numerically stable, especially for matrices with many columns.
Args:
matrix: 2D array representing the matrix to be decomposed into orthogonal
and upper triangular.
precision: numerical precision for matrix multplication. Only relevant on
TPUs.
Returns:
tuple of matrix Q whose columns are orthonormal and R that is upper
triangular.
"""
def orthogonalize(vector, others):
"""Returns the orthogonal component of `vector` with respect to `others`."""
if not others:
return vector / jnp.linalg.norm(vector)
orthogonalize_step = lambda c, x: tuple([c - jnp.dot(c, x) * x, None])
vector, _ = jax.lax.scan(orthogonalize_step, vector, jnp.stack(others))
return vector / jnp.linalg.norm(vector)
num_columns = matrix.shape[1]
columns = unstack(matrix, axis=1)
q_columns = []
r_rows = []
for vec_index, column in enumerate(columns):
next_q_column = orthogonalize(column, q_columns)
r_rows.append(jnp.asarray([
jnp.dot(columns[i], next_q_column) if i >= vec_index else 0.
for i in range(num_columns)]))
q_columns.append(next_q_column)
q = jnp.stack(q_columns, axis=1)
r = jnp.stack(r_rows)
# permute q columns to make entries of r on the diagonal positive.
d = jnp.diag(jnp.sign(jnp.diagonal(r)))
q = jnp.matmul(q, d, precision=precision)
r = jnp.matmul(d, r, precision=precision)
return q, r
def interp1d( # pytype: disable=annotation-type-mismatch # jnp-type
x: Array,
y: Array,
axis: int = -1,
fill_value: Union[str, Array] = jnp.nan,
assume_sorted: bool = True,
) -> Callable[[Array], jax.Array]:
"""Build an interpolation function to approximate `y = f(x)`.
x and y are arrays of values used to approximate some function f: y = f(x).
This returns a function that uses linear interpolation to approximate f
evaluated at new points.
```
x = jnp.linspace(0, 10)
y = jnp.sin(x)
f = interp1d(x, y)
x_new = 1.23
f(x_new)
==> Approximately sin(1.23).
x_new = ... # Shape (4, 5) array
f(x_new)
==> Shape (4, 5) array, approximating sin(x_new).
```
Args:
x: Length N 1-D array of values.
y: Shape (..., N, ...) array of values corresponding to f(x).
axis: Specifies the axis of y along which to interpolate. Interpolation
defaults to the last axis of y.
fill_value: Scalar array or string. If array, this value will be used to
fill in for requested points outside of the data range. If not provided,
then the default is NaN. If "extrapolate", then linear extrapolation is
used. If "constant_extrapolate", then the function is extended as a
constant.
assume_sorted: Whether to assume x is sorted. If True, x must be
monotonically increasing. If False, this function sorts x and reorders
y appropriately.
Returns:
Callable mapping array x_new to values y_new, where
y_new.shape = y.shape[:axis] + x_new.shape + y.shape[axis + 1:]
"""
allowed_fill_value_strs = {'constant_extrapolate', 'extrapolate'}
if isinstance(fill_value, str):
if fill_value not in allowed_fill_value_strs:
raise ValueError(
f'`fill_value` "{fill_value}" not in {allowed_fill_value_strs}')
else:
fill_value = np.asarray(fill_value)
if fill_value.ndim > 0:
raise ValueError(f'Only scalar `fill_value` supported. '
f'Found shape: {fill_value.shape}')
x = jnp.asarray(x)
if x.ndim != 1:
raise ValueError(f'Expected `x` to be 1D. Found shape {x.shape}')
if x.size < 2:
raise ValueError(f'`x` must have at least 2 entries. Found shape {x.shape}')
n_x = x.shape[0]
if not assume_sorted:
ind = jnp.argsort(x)
x = x[ind]
y = jnp.take(y, ind, axis=axis)
y = jnp.asarray(y)
if y.ndim < 1:
raise ValueError(f'Expected `y` to have rank >= 1. Found shape {y.shape}')
if x.size != y.shape[axis]:
raise ValueError(
f'x and y arrays must be equal in length along interpolation axis. '
f'Found x.shape={x.shape} and y.shape={y.shape} and axis={axis}')
axis = _normalize_axis(axis, ndim=y.ndim)
def interp_func(x_new: jax.Array) -> jax.Array:
"""Implementation of the interpolation function."""
x_new = jnp.asarray(x_new)
# We will flatten x_new, then interpolate, then reshape the output.
x_new_shape = x_new.shape
x_new = jnp.ravel(x_new)
# This construction of indices ensures that below_idx < above_idx, even at
# x_new = x[i] exactly for some i.
x_new_clipped = jnp.clip(x_new, np.min(x), np.max(x))
above_idx = jnp.minimum(n_x - 1,
jnp.searchsorted(x, x_new_clipped, side='right'))
below_idx = jnp.maximum(0, above_idx - 1)
def expand(array):
"""Add singletons to rightmost dims of `array` so it bcasts with y."""
array = jnp.asarray(array)
return jnp.reshape(array, array.shape + (1,) * (y.ndim - axis - 1))
x_above = jnp.take(x, above_idx)
x_below = jnp.take(x, below_idx)
y_above = jnp.take(y, above_idx, axis=axis)
y_below = jnp.take(y, below_idx, axis=axis)
slope = (y_above - y_below) / expand(x_above - x_below)
if isinstance(fill_value, str) and fill_value == 'extrapolate':
delta_x = expand(x_new - x_below)
y_new = y_below + delta_x * slope
elif isinstance(fill_value, str) and fill_value == 'constant_extrapolate':
delta_x = expand(x_new_clipped - x_below)
y_new = y_below + delta_x * slope
else: # Else fill_value is an Array.
delta_x = expand(x_new - x_below)
fill_value_ = expand(fill_value)
y_new = y_below + delta_x * slope
y_new = jnp.where(
(delta_x < 0) | (delta_x > expand(x_above - x_below)),
fill_value_, y_new)
return jnp.reshape(
y_new, y_new.shape[:axis] + x_new_shape + y_new.shape[axis + 1:])
return interp_func
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of array utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
import scipy.interpolate as spi
import skimage.measure as skm
BCType = boundaries.BCType
class ArrayUtilsTest(test_util.TestCase):
@parameterized.parameters(
dict(array=np.random.RandomState(1234).randn(3, 6, 9),
block_size=(3, 3, 3),
f=jnp.mean),
dict(array=np.random.RandomState(1234).randn(12, 24, 36),
block_size=(6, 6, 6),
f=jnp.max),
dict(array=np.random.RandomState(1234).randn(12, 24, 36),
block_size=(3, 4, 6),
f=jnp.min),
)
def test_block_reduce(self, array, block_size, f):
"""Test `block_reduce` is equivalent to `skimage.measure.block_reduce`."""
expected_output = skm.block_reduce(array, block_size, f)
actual_output = array_utils.block_reduce(array, block_size, f)
self.assertAllClose(expected_output, actual_output, atol=1e-6)
def test_laplacian_matrix(self):
actual = array_utils.laplacian_matrix(4, step=0.5)
expected = 4.0 * np.array(
[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]])
self.assertAllClose(expected, actual)
@parameterized.parameters(
# Periodic BC
dict(
offset=(0,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
dict(
offset=(0.5,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
dict(
offset=(1.,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
# Dirichlet BC
dict(
offset=(0,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]]),
dict(
offset=(0.5,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-3, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-3]]),
dict(
offset=(1.,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]]),
# Neumann BC
dict(
offset=(0.5,),
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-1]]),
# Neumann-Dirichlet BC
dict(
offset=(0.5,),
bc_types=((BCType.NEUMANN, BCType.DIRICHLET),),
expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-3]]),
)
def test_laplacian_matrix_w_boundaries(self, offset, bc_types, expected):
grid = grids.Grid((4,), step=(.5,))
bc = boundaries.HomogeneousBoundaryConditions(bc_types)
actual = array_utils.laplacian_matrix_w_boundaries(grid, offset, bc)
actual = np.squeeze(actual)
expected = 4.0 * np.array(expected)
self.assertAllClose(expected, actual)
@parameterized.parameters(
dict(matrix=(np.random.RandomState(1234).randn(16, 2))),
dict(matrix=(np.random.RandomState(1234).randn(24, 1))),
dict(matrix=(np.random.RandomState(1234).randn(74, 4))),
)
def test_gram_schmidt_qr(self, matrix):
"""Tests that gram-schmidt_qr is close to numpy for slim matrices."""
q_actual, r_actual = array_utils.gram_schmidt_qr(matrix)
q, r = jnp.linalg.qr(matrix)
# we rearrange the result to make the diagonal of `r` positive.
d = jnp.diag(jnp.sign(jnp.diagonal(r)))
q_expected = q @ d
r_expected = d @ r
self.assertAllClose(q_expected, q_actual, atol=1e-4)
self.assertAllClose(r_expected, r_actual, atol=1e-4)
@parameterized.parameters(
dict(pytree=(np.zeros((6, 3)), np.ones((6, 2, 2))), idx=3, axis=0),
dict(pytree=(np.zeros((3, 8)), np.ones((6, 8))), idx=3, axis=-1),
dict(pytree={'a': np.zeros((3, 9)), 'b': np.ones((6, 9))}, idx=3, axis=1),
dict(pytree=np.zeros((13, 5)), idx=6, axis=0),
dict(pytree=(np.zeros(9), (np.ones((9, 1)), np.ones(9))), idx=6, axis=0),
)
def test_split_and_concat(self, pytree, idx, axis):
"""Tests that split_along_axis, concat_along_axis return expected shapes."""
split_a, split_b = array_utils.split_along_axis(pytree, idx, axis, False)
with self.subTest('split_shape'):
self.assertEqual(jax.tree_util.leaves(split_a)[0].shape[axis], idx)
reconstruction = array_utils.concat_along_axis([split_a, split_b], axis)
with self.subTest('split_concat_roundtrip_structure'):
actual_tree_def = jax.tree_util.structure(reconstruction)
expected_tree_def = jax.tree_util.structure(pytree)
self.assertSameStructure(actual_tree_def, expected_tree_def)
actual_values = jax.tree_util.leaves(reconstruction)
expected_values = jax.tree_util.leaves(pytree)
with self.subTest('split_concat_roundtrip_values'):
for actual, expected in zip(actual_values, expected_values):
self.assertAllClose(actual, expected)
same_ndims = len(set(a.ndim for a in actual_values)) == 1
if not same_ndims:
with self.subTest('raises_when_wrong_ndims'):
with self.assertRaisesRegex(ValueError, 'arrays in `inputs` expected'):
split_a, split_b = array_utils.split_along_axis(pytree, idx, axis)
with self.subTest('multiple_concat_shape'):
arrays = [split_a, split_a, split_b, split_b]
double_concat = array_utils.concat_along_axis(arrays, axis)
actual_shape = jax.tree_util.leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree_util.leaves(pytree)[0].shape[axis] * 2
self.assertEqual(actual_shape, expected_shape)
@parameterized.parameters(
dict(pytree=(np.zeros((6, 3)), np.ones((6, 2, 2))), axis=0),
dict(pytree={'a': np.zeros((3, 9)), 'b': np.ones((6, 9))}, axis=1),
dict(pytree=np.zeros((13, 5)), axis=0),
dict(pytree=(np.zeros(9), (np.ones((9, 1)), np.ones(9))), axis=0),
)
def test_split_along_axis_shapes(self, pytree, axis):
with self.subTest('with_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=True)
get_expected_shape = lambda x: x.shape[:axis] + (1,) + x.shape[axis + 1:]
expected_shapes = jax.tree_util.tree_map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_util.tree_map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)
with self.subTest('without_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=False)
get_expected_shape = lambda x: x.shape[:axis] + x.shape[axis + 1:]
expected_shapes = jax.tree_util.tree_map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_util.tree_map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)
class Interp1DTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='FillWithExtrapolate', fill_value='extrapolate'),
dict(testcase_name='FillWithScalar', fill_value=123),
)
def test_same_as_scipy_on_scalars_and_check_grads(self, fill_value):
rng = np.random.RandomState(45)
n = 10
y = rng.randn(n)
x_low = 5.
x_high = 9.
x = np.linspace(x_low, x_high, num=n)
sp_func = spi.interp1d(
x, y, kind='linear', fill_value=fill_value, bounds_error=False)
cfd_func = array_utils.interp1d(x, y, fill_value=fill_value)
# Check x_new at the table definition points `x`, points outside, and points
# in between.
x_to_check = np.concatenate(([x_low - 1], x, x * 1.051, [x_high + 1]))
for x_new in x_to_check:
sp_y_new = sp_func(x_new).astype(np.float32)
cfd_y_new = cfd_func(x_new)
self.assertAllClose(sp_y_new, cfd_y_new, rtol=1e-5, atol=1e-6)
grad_cfd_y_new = jax.grad(cfd_func)(x_new)
# Gradients should be nonzero except when outside the range of `x` and
# filling with a constant.
# Why check this? Because, some indexing methods result in gradients == 0
# at the interpolation table points.
if fill_value == 'extrapolate' or x_low <= x_new <= x_high:
self.assertLess(0, np.abs(grad_cfd_y_new))
else:
self.assertTrue(np.isfinite(grad_cfd_y_new))
@parameterized.named_parameters(
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithScalar',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value=12345,
),
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithScalar_NoAssumeSorted',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value=12345,
assume_sorted=False,
),
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithExtrapolate',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
),
dict(
testcase_name='TableIs2D_XNewIs1D_Axis0_FillWithExtrapolate',
table_ndim=2,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
),
dict(
testcase_name=(
'TableIs2D_XNewIs1D_Axis0_FillWithExtrapolate_NoAssumeSorted'),
table_ndim=2,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
assume_sorted=False,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axis1_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=1,
fill_value=12345,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn1_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=-1,
fill_value=12345,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn1_FillWithExtrapolate',
table_ndim=3,
x_new_ndim=2,
axis=-1,
fill_value='extrapolate',
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn3_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value=1234,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn3_FillWithConstantExtrapolate',
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value='constant_extrapolate',
),
dict(
testcase_name=(
'Table3D_XNew2D_Axisn3_FillConstantExtrapolate_NoAssumeSorted'),
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value='constant_extrapolate',
assume_sorted=False,
),
)
def test_same_as_scipy_on_arrays(
self,
table_ndim,
x_new_ndim,
axis,
fill_value,
assume_sorted=True,
):
"""Test results are the same as scipy.interpolate.interp1d."""
rng = np.random.RandomState(45)
# Arbitrary shape that ensures all dims are different to prevent
# broadcasting from hiding bugs.
y_shape = tuple(range(5, 5 + table_ndim))
y = rng.randn(*y_shape)
n = y_shape[axis]
x_low = 5
x_high = 9
x = np.linspace(x_low, x_high, num=n)**2 # Arbitrary non-linearly spaced x
if not assume_sorted:
rng.shuffle(x)
# Scipy doesn't have 'constant_extrapolate', so treat it special.
# Here use np.nan as the fill value, which will be easy to spot if we handle
# it wrong.
if fill_value == 'constant_extrapolate':
sp_fill_value = np.nan
else:
sp_fill_value = fill_value
sp_func = spi.interp1d(
x,
y,
kind='linear',
axis=axis,
fill_value=sp_fill_value,
bounds_error=False,
assume_sorted=assume_sorted,
)
cfd_func = array_utils.interp1d(
x, y, axis=axis, fill_value=fill_value, assume_sorted=assume_sorted)
# Make n_x_new > n, so we can selectively fill values as below.
n_x_new = max(2 * n, 20)
# Make x_new over the same range as x.
x_new_shape = tuple(range(2, x_new_ndim + 1)) + (n_x_new,)
x_new = (x_low + rng.rand(*x_new_shape) * (x_high - x_low))**2
x_new[..., 0] = np.min(x) - 1 # Out of bounds low
x_new[..., -1] = np.max(x) + 1 # Out of bounds high
x_new[..., 1:n + 1] = x # All the grid points
# Scipy doesn't have the 'constant_extrapolate' feature, but
# constant_extrapolate is achieved by clipping the input.
if fill_value == 'constant_extrapolate':
sp_x_new = np.clip(x_new, np.min(x), np.max(x))
else:
sp_x_new = x_new
sp_y_new = sp_func(sp_x_new).astype(np.float32)
cfd_y_new = cfd_func(x_new)
self.assertAllClose(sp_y_new, cfd_y_new, rtol=1e-6, atol=1e-6)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes that specify how boundary conditions are applied to arrays."""
import dataclasses
import math
from typing import Optional, Sequence, Tuple, Union
import jax
from jax import lax
import jax.numpy as jnp
from jax_cfd.base import grids
import numpy as np
BoundaryConditions = grids.BoundaryConditions
GridArray = grids.GridArray
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
Array = Union[np.ndarray, jax.Array]
class BCType:
PERIODIC = 'periodic'
DIRICHLET = 'dirichlet'
NEUMANN = 'neumann'
class Padding:
MIRROR = 'mirror'
EXTEND = 'extend'
@dataclasses.dataclass(init=False, frozen=True)
class ConstantBoundaryConditions(BoundaryConditions):
"""Boundary conditions for a PDE variable that are constant in space and time.
Example usage:
grid = Grid((10, 10))
array = GridArray(np.zeros((10, 10)), offset=(0.5, 0.5), grid)
bc = ConstantBoundaryConditions(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)),
((0.0, 10.0),(1.0, 0.0)))
u = GridVariable(array, bc)
Attributes:
types: `types[i]` is a tuple specifying the lower and upper BC types for
dimension `i`.
"""
types: Tuple[Tuple[str, str], ...]
bc_values: Tuple[Tuple[Optional[float], Optional[float]], ...]
def __init__(self, types: Sequence[Tuple[str, str]],
values: Sequence[Tuple[Optional[float], Optional[float]]]):
types = tuple(types)
values = tuple(values)
object.__setattr__(self, 'types', types)
object.__setattr__(self, 'bc_values', values)
def shift(
self,
u: GridArray,
offset: int,
axis: int,
mode: Optional[str] = Padding.EXTEND,
) -> GridArray:
"""Shift an GridArray by `offset`.
Args:
u: an `GridArray` object.
offset: positive or negative integer offset to shift.
axis: axis to shift along.
mode: type of padding to use in non-periodic case.
Mirror mirrors the flow across the boundary.
Extend extends the last well-defined value past the boundary.
Returns:
A copy of `u`, shifted by `offset`. The returned `GridArray` has offset
`u.offset + offset`.
"""
padded = self._pad(u, offset, axis, mode=mode)
trimmed = self._trim(padded, -offset, axis)
return trimmed
def _is_aligned(self, u: GridArray, axis: int) -> bool:
"""Checks if array u contains all interior domain information.
For dirichlet edge aligned boundary, the value that lies exactly on the
boundary does not have to be specified by u.
Neumann edge aligned boundary is not defined.
Args:
u: array that should contain interior data
axis: axis along which to check
Returns:
True if u is aligned, and raises error otherwise.
"""
size_diff = u.shape[axis] - u.grid.shape[axis]
if self.types[axis][0] == BCType.DIRICHLET and np.isclose(
u.offset[axis], 1):
size_diff += 1
if self.types[axis][1] == BCType.DIRICHLET and np.isclose(
u.offset[axis], 1):
size_diff += 1
if self.types[axis][0] == BCType.NEUMANN and np.isclose(
u.offset[axis] % 1, 0):
raise NotImplementedError('Edge-aligned neumann BC are not implemented.')
if size_diff < 0:
raise ValueError(
'the GridArray does not contain all interior grid values.')
return True
def _pad(
self,
u: GridArray,
width: int,
axis: int,
mode: Optional[str] = Padding.EXTEND,
) -> GridArray:
"""Pad a GridArray.
For dirichlet boundary, u is mirrored.
Important: For jax_cfd finite difference/finite-volume code, no more than 1
ghost cell is required. More ghost cells are used only in LES filtering/CNN
application.
Args:
u: a `GridArray` object.
width: number of elements to pad along axis. Use negative value for lower
boundary or positive value for upper boundary.
axis: axis to pad along.
mode: type of padding to use in non-periodic case.
Mirror mirrors the array values across the boundary.
Extend extends the last well-defined array value past the boundary.
Mode is only needed if the padding extends past array values that are
defined by the physics. In these cases, no mode is necessary. This
also means periodic boundaries do not require a mode and can use
mode=None.
Returns:
Padded array, elongated along the indicated axis.
"""
def make_padding(width):
if width < 0: # pad lower boundary
bc_type = self.types[axis][0]
padding = (-width, 0)
else: # pad upper boundary
bc_type = self.types[axis][1]
padding = (0, width)
full_padding = [(0, 0)] * u.grid.ndim
full_padding[axis] = padding
return full_padding, padding, bc_type
full_padding, padding, bc_type = make_padding(width)
offset = list(u.offset)
offset[axis] -= padding[0]
if bc_type == BCType.PERIODIC:
need_trimming = 'both' # need to trim both sides
elif width >= 0:
need_trimming = 'right' # only one side needs to be trimmed
else:
need_trimming = 'left' # only one side needs to be trimmed
u, trimmed_padding = self._trim_padding(u, axis, need_trimming)
data = u.data
full_padding[axis] = tuple(
pad + trimmed_pad
for pad, trimmed_pad in zip(full_padding[axis], trimmed_padding))
if bc_type == BCType.PERIODIC:
# for periodic, all grid points must be there. Otherwise padding doesn't
# make sense.
# self.values are ignored here
pad_kwargs = dict(mode='wrap')
data = jnp.pad(data, full_padding, **pad_kwargs)
elif bc_type == BCType.DIRICHLET:
if np.isclose(u.offset[axis] % 1, 0.5): # cell center
# If only one or 0 value is needed, no mode is necessary.
# All modes would return the same values.
if np.isclose(sum(full_padding[axis]), 1) or np.isclose(
sum(full_padding[axis]), 0):
mode = Padding.MIRROR
if mode == Padding.MIRROR:
# make the linearly interpolated value equal to the boundary by
# setting the padded values to the negative symmetric values
data = (2 * jnp.pad(
data,
full_padding,
mode='constant',
constant_values=self.bc_values) -
jnp.pad(data, full_padding, mode='symmetric'))
elif mode == Padding.EXTEND:
# computes the well-defined ghost cell and sets the rest of padding
# values equal to the ghost cell.
data = (2 * jnp.pad(
data,
full_padding,
mode='constant',
constant_values=self.bc_values) -
jnp.pad(data, full_padding, mode='edge'))
else:
raise NotImplementedError(f'Mode {mode} is not implemented yet.')
elif np.isclose(u.offset[axis] % 1, 0): # cell edge
# u specifies the values on the interior CV. Thus, first the value on
# the boundary needs to be added to the array, if not specified by the
# interior CV values.
# Then the mirrored ghost cells need to be appended.
# if only one value is needed, no mode is necessary.
if np.isclose(sum(full_padding[axis]), 1) or np.isclose(
sum(full_padding[axis]), 0):
data = jnp.pad(
data,
full_padding,
mode='constant',
constant_values=self.bc_values)
elif sum(full_padding[axis]) > 1:
if mode == Padding.MIRROR:
# make boundary-only padding
bc_padding = [(0, 0)] * u.grid.ndim
bc_padding[axis] = tuple(
1 if pad > 0 else 0 for pad in full_padding[axis])
# subtract the padded cell
full_padding_past_bc = [(0, 0)] * u.grid.ndim
full_padding_past_bc[axis] = tuple(
pad - 1 if pad > 0 else 0 for pad in full_padding[axis])
# here we are adding 0 boundary cell with 0 value
expanded_data = jnp.pad(
data, bc_padding, mode='constant', constant_values=(0, 0))
padding_values = list(self.bc_values)
padding_values[axis] = [pad / 2 for pad in padding_values[axis]]
data = 2 * jnp.pad(
data,
full_padding,
mode='constant',
constant_values=tuple(padding_values)) - jnp.pad(
expanded_data, full_padding_past_bc, mode='reflect')
elif mode == Padding.EXTEND:
data = jnp.pad(
data,
full_padding,
mode='constant',
constant_values=self.bc_values)
else:
raise NotImplementedError(f'Mode {mode} is not implemented yet.')
else:
raise ValueError('expected offset to be an edge or cell center, got '
f'offset[axis]={u.offset[axis]}')
elif bc_type == BCType.NEUMANN:
if not np.isclose(u.offset[axis] % 1, 0.5):
raise ValueError(
'expected offset to be cell center for neumann bc, got '
f'offset[axis]={u.offset[axis]}')
else:
# When the data is cell-centered, computes the backward difference.
# if only one value is needed, no mode is necessary. Default mode is
# provided, although all modes would return the same values.
if np.isclose(sum(full_padding[axis]), 1) or np.isclose(
sum(full_padding[axis]), 0):
np_mode = 'symmetric'
elif mode == Padding.MIRROR:
np_mode = 'symmetric'
elif mode == Padding.EXTEND:
np_mode = 'edge'
else:
raise NotImplementedError(f'Mode {mode} is not implemented yet.')
# ensures that finite_differences.backward_difference satisfies the
# boundary condition.
derivative_direction = float(width // max(1, abs(width)))
data = (
jnp.pad(data, full_padding, mode=np_mode) -
derivative_direction * u.grid.step[axis] *
(jnp.pad(data, full_padding, mode='constant') - jnp.pad(
data,
full_padding,
mode='constant',
constant_values=self.bc_values)))
else:
raise ValueError('invalid boundary type')
return GridArray(data, tuple(offset), u.grid)
def _trim(
self,
u: GridArray,
width: int,
axis: int,
) -> GridArray:
"""Trim padding from a GridArray.
Args:
u: a `GridArray` object.
width: number of elements to trim along axis. Use negative value for lower
boundary or positive value for upper boundary.
axis: axis to trim along.
Returns:
Trimmed array, shrunk along the indicated axis.
"""
if width < 0: # trim lower boundary
padding = (-width, 0)
else: # trim upper boundary
padding = (0, width)
limit_index = u.data.shape[axis] - padding[1]
data = lax.slice_in_dim(u.data, padding[0], limit_index, axis=axis)
offset = list(u.offset)
offset[axis] += padding[0]
return GridArray(data, tuple(offset), u.grid)
def _trim_padding(self,
u: grids.GridArray,
axis: int = 0,
trim_side: str = 'both'):
"""Trims padding from a GridArray along axis and returns the array interior.
Args:
u: a `GridArray` object.
axis: axis to trim along.
trim_side: if 'both', trims both sides. If 'right', trims the right side.
If 'left', the left side.
Returns:
Trimmed array, shrunk along the indicated axis side.
"""
padding = (0, 0)
if u.shape[axis] >= u.grid.shape[axis]:
# number of cells that were padded on the left
negative_trim = 0
if u.offset[axis] <= 0 and (trim_side == 'both' or trim_side == 'left'):
negative_trim = -math.ceil(-u.offset[axis])
# periodic is a special case. Shifted data might still contain all the
# information.
if self.types[axis][0] == BCType.PERIODIC:
negative_trim = max(negative_trim, u.grid.shape[axis] - u.shape[axis])
# for both DIRICHLET and NEUMANN cases the value on grid.domain[0] is
# a dependent value.
elif np.isclose(u.offset[axis] % 1, 0):
negative_trim -= 1
u = self._trim(u, negative_trim, axis)
# number of cells that were padded on the right
positive_trim = 0
if (trim_side == 'right' or trim_side == 'both'):
# periodic is a special case. Boundary on one side depends on the other
# side.
if self.types[axis][1] == BCType.PERIODIC:
positive_trim = max(u.shape[axis] - u.grid.shape[axis], 0)
else:
# for other cases, where to trim depends only on the boundary type
# and data offset.
last_u_offset = u.shape[axis] + u.offset[axis] - 1
boundary_offset = u.grid.shape[axis]
if last_u_offset >= boundary_offset:
positive_trim = math.ceil(last_u_offset - boundary_offset)
if self.types[axis][1] == BCType.DIRICHLET and np.isclose(
u.offset[axis] % 1, 0):
positive_trim += 1
if positive_trim > 0:
u = self._trim(u, positive_trim, axis)
# combining existing padding with new padding
padding = (-negative_trim, positive_trim)
return u, padding
def pad(self,
u: GridArray,
width: Union[Tuple[int, int], int],
axis: int,
mode: Optional[str] = Padding.EXTEND,
) -> GridArray:
"""Wrapper for _pad.
Args:
u: a `GridArray` object.
width: number of elements to pad along axis. If width is an int, use
negative value for lower boundary or positive value for upper boundary.
If a tuple, pads with width[0] on the left and width[1] on the right.
axis: axis to pad along.
mode: type of padding to use in non-periodic case.
Mirror mirrors the array values across the boundary.
Extend extends the last well-defined array value past the boundary.
Returns:
Padded array, elongated along the indicated axis.
"""
_ = self._is_aligned(u, axis)
if isinstance(width, int):
u = self._pad(u, width, axis, mode=mode)
else:
u = self._pad(u, -width[0], axis, mode=mode)
u = self._pad(u, width[1], axis, mode=mode)
return u
def pad_all(self,
u: GridArray,
width: Tuple[Tuple[int, int], ...],
mode: Optional[str] = Padding.EXTEND
) -> GridArray:
"""Pads along all axes with pad width specified by width tuple.
Args:
u: a `GridArray` object.
width: Tuple of padding width for each side for each axis.
mode: type of padding to use in non-periodic case.
Mirror mirrors the array values across the boundary.
Extend extends the last well-defined array value past the boundary.
Returns:
Padded array, elongated along all axes.
"""
for axis in range(u.grid.ndim):
u = self.pad(u, width[axis], axis, mode=mode)
return u
def values( # pytype: disable=signature-mismatch # overriding-parameter-count-checks
self, axis: int,
grid: grids.Grid) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]:
"""Returns boundary values on the grid along axis.
Args:
axis: axis along which to return boundary values.
grid: a `Grid` object on which to evaluate boundary conditions.
Returns:
A tuple of arrays of grid.ndim - 1 dimensions that specify values on the
boundary. In case of periodic boundaries, returns a tuple(None,None).
"""
if None in self.bc_values[axis]:
return (None, None)
bc_values = tuple(
jnp.full(grid.shape[:axis] +
grid.shape[axis + 1:], self.bc_values[axis][-i])
for i in [0, 1])
return bc_values
def trim_boundary(self, u: grids.GridArray) -> grids.GridArray:
"""Returns GridArray without the grid points on the boundary.
Some grid points of GridArray might coincide with boundary. This trims those
values. If the array was padded beforehand, removes the padding.
Args:
u: a `GridArray` object.
Returns:
A GridArray shrunk along certain dimensions.
"""
for axis in range(u.grid.ndim):
_ = self._is_aligned(u, axis)
u, _ = self._trim_padding(u, axis)
return u
def pad_and_impose_bc(
self,
u: grids.GridArray,
offset_to_pad_to: Optional[Tuple[float,...]] = None,
mode: Optional[str] = Padding.EXTEND,
) -> grids.GridVariable:
"""Returns GridVariable with correct boundary values.
Some grid points of GridArray might coincide with boundary. This ensures
that the GridVariable.array agrees with GridVariable.bc.
Args:
u: a `GridArray` object that specifies only scalar values on the internal
nodes.
offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the
function is given just an interior array in dirichlet case, it can pad
to both 0 offset and 1 offset.
mode: type of padding to use in non-periodic case.
Mirror mirrors the flow across the boundary.
Extend extends the last well-defined value past the boundary.
Returns:
A GridVariable that has correct boundary values.
"""
if offset_to_pad_to is None:
offset_to_pad_to = u.offset
for axis in range(u.grid.ndim):
_ = self._is_aligned(u, axis)
if self.types[axis][0] == BCType.DIRICHLET and np.isclose(
u.offset[axis], 1.0):
if np.isclose(offset_to_pad_to[axis], 1.0):
u = self._pad(u, 1, axis, mode=mode)
elif np.isclose(offset_to_pad_to[axis], 0.0):
u = self._pad(u, -1, axis, mode=mode)
return grids.GridVariable(u, self)
def impose_bc(self, u: grids.GridArray) -> grids.GridVariable:
"""Returns GridVariable with correct boundary condition.
Some grid points of GridArray might coincide with boundary. This ensures
that the GridVariable.array agrees with GridVariable.bc.
Args:
u: a `GridArray` object.
Returns:
A GridVariable that has correct boundary values and is restricted to the
domain.
"""
offset = u.offset
u = self.trim_boundary(u)
return self.pad_and_impose_bc(u, offset)
trim = _trim
class HomogeneousBoundaryConditions(ConstantBoundaryConditions):
"""Boundary conditions for a PDE variable.
Example usage:
grid = Grid((10, 10))
array = GridArray(np.zeros((10, 10)), offset=(0.5, 0.5), grid)
bc = ConstantBoundaryConditions(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)))
u = GridVariable(array, bc)
Attributes:
types: `types[i]` is a tuple specifying the lower and upper BC types for
dimension `i`.
"""
def __init__(self, types: Sequence[Tuple[str, str]]):
ndim = len(types)
values = ((0.0, 0.0),) * ndim
super(HomogeneousBoundaryConditions, self).__init__(types, values)
# Convenience utilities to ease updating of BoundaryConditions implementation
def periodic_boundary_conditions(ndim: int) -> ConstantBoundaryConditions:
"""Returns periodic BCs for a variable with `ndim` spatial dimension."""
return HomogeneousBoundaryConditions(
((BCType.PERIODIC, BCType.PERIODIC),) * ndim)
def dirichlet_boundary_conditions(
ndim: int,
bc_vals: Optional[Sequence[Tuple[float, float]]] = None,
) -> ConstantBoundaryConditions:
"""Returns Dirichelt BCs for a variable with `ndim` spatial dimension.
Args:
ndim: spatial dimension.
bc_vals: A tuple of lower and upper boundary values for each dimension.
If None, returns Homogeneous BC.
Returns:
BoundaryCondition instance.
"""
if not bc_vals:
return HomogeneousBoundaryConditions(
((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim)
else:
return ConstantBoundaryConditions(
((BCType.DIRICHLET, BCType.DIRICHLET),) * ndim, bc_vals)
def neumann_boundary_conditions(
ndim: int,
bc_vals: Optional[Sequence[Tuple[float, float]]] = None,
) -> ConstantBoundaryConditions:
"""Returns Neumann BCs for a variable with `ndim` spatial dimension.
Args:
ndim: spatial dimension.
bc_vals: A tuple of lower and upper boundary values for each dimension.
If None, returns Homogeneous BC.
Returns:
BoundaryCondition instance.
"""
if not bc_vals:
return HomogeneousBoundaryConditions(
((BCType.NEUMANN, BCType.NEUMANN),) * ndim)
else:
return ConstantBoundaryConditions(
((BCType.NEUMANN, BCType.NEUMANN),) * ndim, bc_vals)
def channel_flow_boundary_conditions(
ndim: int,
bc_vals: Optional[Sequence[Tuple[float, float]]] = None,
) -> ConstantBoundaryConditions:
"""Returns BCs periodic for dimension 0 and Dirichlet for dimension 1.
Args:
ndim: spatial dimension.
bc_vals: A tuple of lower and upper boundary values for each dimension.
If None, returns Homogeneous BC. For periodic dimensions the lower, upper
boundary values should be (None, None).
Returns:
BoundaryCondition instance.
"""
bc_type = ((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET))
for _ in range(ndim - 2):
bc_type += ((BCType.PERIODIC, BCType.PERIODIC),)
if not bc_vals:
return HomogeneousBoundaryConditions(bc_type)
else:
return ConstantBoundaryConditions(bc_type, bc_vals)
def periodic_and_neumann_boundary_conditions(
bc_vals: Optional[Tuple[float,
float]] = None) -> ConstantBoundaryConditions:
"""Returns BCs periodic for dimension 0 and Neumann for dimension 1.
Args:
bc_vals: the lower and upper boundary condition value for each dimension. If
None, returns Homogeneous BC.
Returns:
BoundaryCondition instance.
"""
if not bc_vals:
return HomogeneousBoundaryConditions(
((BCType.PERIODIC, BCType.PERIODIC), (BCType.NEUMANN, BCType.NEUMANN)))
else:
return ConstantBoundaryConditions(
((BCType.PERIODIC, BCType.PERIODIC), (BCType.NEUMANN, BCType.NEUMANN)),
((None, None), bc_vals))
def periodic_and_dirichlet_boundary_conditions(
bc_vals: Optional[Tuple[float, float]] = None,
periodic_axis=0) -> ConstantBoundaryConditions:
"""Returns BCs periodic for dimension 0 and Dirichlet for dimension 1.
Args:
bc_vals: the lower and upper boundary condition value for each dimension. If
None, returns Homogeneous BC.
periodic_axis: specifies which axis is periodic.
Returns:
BoundaryCondition subclass.
"""
periodic = (BCType.PERIODIC, BCType.PERIODIC)
dirichlet = (BCType.DIRICHLET, BCType.DIRICHLET)
if periodic_axis == 0:
if not bc_vals:
return HomogeneousBoundaryConditions((periodic, dirichlet))
else:
return ConstantBoundaryConditions((periodic, dirichlet),
((None, None), bc_vals))
else:
if not bc_vals:
return HomogeneousBoundaryConditions((dirichlet, periodic))
else:
return ConstantBoundaryConditions((dirichlet, periodic),
(bc_vals, (None, None)))
def is_periodic_boundary_conditions(c: grids.GridVariable, axis: int) -> bool:
"""Returns true if scalar has periodic bc along axis."""
if c.bc.types[axis][0] != BCType.PERIODIC:
return False
return True
def has_all_periodic_boundary_conditions(*arrays: GridVariable) -> bool:
"""Returns True if arrays have periodic BC in every dimension, else False."""
for array in arrays:
for axis in range(array.grid.ndim):
if not is_periodic_boundary_conditions(array, axis):
return False
return True
def consistent_boundary_conditions(*arrays: GridVariable) -> Tuple[str, ...]:
"""Returns whether BCs are periodic.
Mixed periodic/nonperiodic boundaries along the same boundary do not make
sense. The function checks that the boundary is either periodic or not and
throws an error if its mixed.
Args:
*arrays: a list of gridvariables.
Returns:
a list of types of boundaries corresponding to each axis if
they are consistent.
"""
bc_types = []
for axis in range(arrays[0].grid.ndim):
bcs = {is_periodic_boundary_conditions(array, axis) for array in arrays}
if len(bcs) != 1:
raise grids.InconsistentBoundaryConditionsError(
f'arrays do not have consistent bc: {arrays}')
elif bcs.pop():
bc_types.append('periodic')
else:
bc_types.append('nonperiodic')
return tuple(bc_types)
def get_pressure_bc_from_velocity(
v: GridVariableVector) -> HomogeneousBoundaryConditions:
"""Returns pressure boundary conditions for the specified velocity."""
# assumes that if the boundary is not periodic, pressure BC is zero flux.
velocity_bc_types = consistent_boundary_conditions(*v)
pressure_bc_types = []
for velocity_bc_type in velocity_bc_types:
if velocity_bc_type == 'periodic':
pressure_bc_types.append((BCType.PERIODIC, BCType.PERIODIC))
else:
pressure_bc_types.append((BCType.NEUMANN, BCType.NEUMANN))
return HomogeneousBoundaryConditions(pressure_bc_types)
def get_advection_flux_bc_from_velocity_and_scalar(
u: GridVariable, c: GridVariable,
flux_direction: int) -> ConstantBoundaryConditions:
"""Returns advection flux boundary conditions for the specified velocity.
Infers advection flux boundary condition in flux direction
from scalar c and velocity u in direction flux_direction.
The flux boundary condition should be used only to compute divergence.
If the boundaries are periodic, flux is periodic.
In nonperiodic case, flux boundary parallel to flux direction is
homogeneous dirichlet.
In nonperiodic case if flux direction is normal to the wall, the
function supports 2 cases:
1) Nonporous boundary, corresponding to homogeneous flux bc.
2) Pourous boundary with constant flux, corresponding to
both the velocity and scalar with Homogeneous Neumann bc.
This function supports only these cases because all other cases result in
time dependent flux boundary condition.
Args:
u: velocity component in flux_direction.
c: scalar to advect.
flux_direction: direction of velocity.
Returns:
BoundaryCondition instance for advection flux of c in flux_direction.
"""
# only no penetration and periodic boundaries are supported.
flux_bc_types = []
flux_bc_values = []
if not isinstance(u.bc, HomogeneousBoundaryConditions):
raise NotImplementedError(
f'Flux boundary condition is not implemented for velocity with {u.bc}')
for axis in range(c.grid.ndim):
if u.bc.types[axis][0] == 'periodic':
flux_bc_types.append((BCType.PERIODIC, BCType.PERIODIC))
flux_bc_values.append((None, None))
elif flux_direction != axis:
# This is not technically correct. Flux boundary condition in most cases
# is a time dependent function of the current values of the scalar
# and velocity. However, because flux is used only to take divergence, the
# boundary condition on the flux along the boundary parallel to the flux
# direction has no influence on the computed divergence because the
# boundary condition only alters ghost cells, while divergence is computed
# on the interior.
# To simplify the code and allow for flux to be wrapped in gridVariable,
# we are setting the boundary to homogeneous dirichlet.
# Note that this will not work if flux is used in any other capacity but
# to take divergence.
flux_bc_types.append((BCType.DIRICHLET, BCType.DIRICHLET))
flux_bc_values.append((0.0, 0.0))
else:
flux_bc_types_ax = []
flux_bc_values_ax = []
for i in range(2): # lower and upper boundary.
# case 1: nonpourous boundary
if (u.bc.types[axis][i] == BCType.DIRICHLET and
u.bc.bc_values[axis][i] == 0.0):
flux_bc_types_ax.append(BCType.DIRICHLET)
flux_bc_values_ax.append(0.0)
# case 2: zero flux boundary
elif (u.bc.types[axis][i] == BCType.NEUMANN and
c.bc.types[axis][i] == BCType.NEUMANN):
if not isinstance(c.bc, ConstantBoundaryConditions):
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
if not np.isclose(c.bc.bc_values[axis][i], 0.0):
raise NotImplementedError(
'Flux boundary condition is not implemented for scalar' +
f' with {c.bc}')
flux_bc_types_ax.append(BCType.NEUMANN)
flux_bc_values_ax.append(0.0)
# no other case is supported
else:
raise NotImplementedError(
f'Flux boundary condition is not implemented for {u.bc, c.bc}')
flux_bc_types.append(flux_bc_types_ax)
flux_bc_values.append(flux_bc_values_ax)
return ConstantBoundaryConditions(flux_bc_types, flux_bc_values)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.boundaries."""
# TODO(jamieas): Consider updating these tests using the `hypothesis` framework.
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
BCType = boundaries.BCType
class ConstantBoundaryConditionsTest(test_util.TestCase):
def test_init_usage(self):
with self.subTest('init 1d'):
bc = boundaries.HomogeneousBoundaryConditions(
((BCType.PERIODIC, BCType.PERIODIC)))
self.assertEqual(bc.types, (('periodic', 'periodic')))
with self.subTest('init 2d'):
bc = boundaries.HomogeneousBoundaryConditions([
(BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)
])
self.assertEqual(bc.types, (
('periodic', 'periodic'),
('dirichlet', 'dirichlet'),
))
with self.subTest('periodic bc utility'):
bc = boundaries.periodic_boundary_conditions(ndim=3)
self.assertEqual(bc.types, (
('periodic', 'periodic'),
('periodic', 'periodic'),
('periodic', 'periodic'),
))
with self.subTest('dirichlet bc utility'):
bc = boundaries.dirichlet_boundary_conditions(ndim=3)
self.assertEqual(bc.types, (
('dirichlet', 'dirichlet'),
('dirichlet', 'dirichlet'),
('dirichlet', 'dirichlet'),
))
with self.subTest('neumann bc utility'):
bc = boundaries.neumann_boundary_conditions(ndim=3)
self.assertEqual(bc.types, (
('neumann', 'neumann'),
('neumann', 'neumann'),
('neumann', 'neumann'),
))
with self.subTest('channel flow 2d bc utility'):
bc = boundaries.channel_flow_boundary_conditions(ndim=2)
self.assertEqual(bc.types, (
('periodic', 'periodic'),
('dirichlet', 'dirichlet'),
))
with self.subTest('channel flow 3d bc utility'):
bc = boundaries.channel_flow_boundary_conditions(ndim=3)
self.assertEqual(bc.types, (
('periodic', 'periodic'),
('dirichlet', 'dirichlet'),
('periodic', 'periodic'),
))
with self.subTest('periodic and neumann bc utility'):
bc = boundaries.periodic_and_neumann_boundary_conditions()
self.assertEqual(bc.types, (
('periodic', 'periodic'),
('neumann', 'neumann'),
))
@parameterized.parameters(
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(0,),
),
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(1,),
),
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(-1,),
),
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(5,),
),
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(13,),
),
dict(
shape=(11,),
initial_offset=(0.0,),
step=1,
offset=(31,),
),
dict(
shape=(11, 12, 17),
initial_offset=(-0.5, -1.0, 1.0),
step=0.1,
offset=(-236, 10001, 3),
),
dict(
shape=(121,),
initial_offset=(-0.5,),
step=1,
offset=(31,),
),
dict(
shape=(11, 12, 17),
initial_offset=(0.5, 0.0, 1.0),
step=0.1,
offset=(-236, 10001, 3),
),
)
def test_shift_periodic(self, shape, initial_offset, step, offset):
"""Test that `shift` returns the expected values for periodic BC."""
grid = grids.Grid(shape, step)
data = np.arange(np.prod(shape)).reshape(shape)
array = grids.GridArray(data, initial_offset, grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
shifted_array = array
for axis, o in enumerate(offset):
shifted_array = bc.shift(shifted_array, o, axis)
shifted_indices = [(jnp.arange(s) + o) % s for s, o in zip(shape, offset)]
shifted_mesh = jnp.meshgrid(*shifted_indices, indexing='ij')
expected_offset = tuple(i + o for i, o in zip(initial_offset, offset))
expected = grids.GridArray(data[tuple(shifted_mesh)], expected_offset, grid)
self.assertArrayEqual(shifted_array, expected)
@parameterized.parameters(
# Periodic BC
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
shift_offset=-2,
expected_data=np.array([13, 14, 11, 12]),
expected_offset=(-2,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
shift_offset=-1,
expected_data=np.array([14, 11, 12, 13]),
expected_offset=(-1,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
shift_offset=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
shift_offset=1,
expected_data=np.array([12, 13, 14, 11]),
expected_offset=(1,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
shift_offset=2,
expected_data=np.array([13, 14, 11, 12]),
expected_offset=(2,),
),
# Dirichlet BC
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([11, 12, 13, 0]),
input_offset=(1,),
shift_offset=-1,
expected_data=np.array([0, 11, 12, 13]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
shift_offset=0,
expected_data=np.array([0, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
shift_offset=1,
expected_data=np.array([12, 13, 14, 0]),
expected_offset=(1,),
),
# Neumann BC
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
shift_offset=-1,
expected_data=np.array([11, 11, 12, 13]),
expected_offset=(-.5,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
shift_offset=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
shift_offset=1,
expected_data=np.array([12, 13, 14, 14]),
expected_offset=(1.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=((BCType.DIRICHLET, BCType.NEUMANN),),
input_data=np.array([0, 12, 13, 14, 15]),
input_offset=(0.5,),
shift_offset=-1,
expected_data=np.array([0, 0, 12, 13, 14]),
expected_offset=(-.5,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.NEUMANN),),
input_data=np.array([0, 12, 13, 14, 15]),
input_offset=(0.5,),
shift_offset=1,
expected_data=np.array([12, 13, 14, 15, 15]),
expected_offset=(1.5,),
),
)
def test_shift_1d(self, bc_types, input_data, input_offset, shift_offset,
expected_data, expected_offset):
grid = grids.Grid(input_data.shape)
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.HomogeneousBoundaryConditions(bc_types)
actual = bc.shift(array, shift_offset, axis=0)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Periodic BC
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=-2,
expected_data=np.array([13, 14, 11, 12, 13, 14]),
expected_offset=(-2,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=-1,
expected_data=np.array([14, 11, 12, 13, 14]),
expected_offset=(-1,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=1,
expected_data=np.array([11, 12, 13, 14, 11]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=2,
expected_data=np.array([11, 12, 13, 14, 11, 12]),
expected_offset=(0,),
),
# Dirichlet BC
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
width=0,
expected_data=np.array([0, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
width=1,
expected_data=np.array([0, 12, 13, 14, 0]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([12, 13, 14, 0]),
input_offset=(1,),
width=-1,
expected_data=np.array([0, 12, 13, 14, 0]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
width=1,
expected_data=np.array([0, 12, 13, 14, 0]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=1,
expected_data=np.array([11, 12, 13, 14, -14]),
expected_offset=(0.5,),
),
# Neumann BC
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
width=-1,
expected_data=np.array([11, 11, 12, 13, 14]),
expected_offset=(-.5,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
width=-1,
expected_data=np.array([11, 11, 12, 13, 14]),
expected_offset=(-.5,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
width=1,
expected_data=np.array([11, 12, 13, 14, 14]),
expected_offset=(.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=((BCType.DIRICHLET, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=-1,
expected_data=np.array([-11, 11, 12, 13, 14]),
expected_offset=(-.5,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=1,
expected_data=np.array([11, 12, 13, 14, 14]),
expected_offset=(.5,),
),
)
def test_pad_1d_no_mode(self, bc_types, input_data, input_offset, width,
expected_data, expected_offset):
grid = grids.Grid(input_data.shape)
input_data = input_data.astype('float')
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.HomogeneousBoundaryConditions(bc_types)
actual = bc.pad(array, width, axis=0)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
width=-3,
expected_data=np.array([-14, -13, -12, 0, 12, 13, 14]),
expected_offset=(-3,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
width=3,
expected_data=np.array([0, 12, 13, 14, 0, -14, -13]),
expected_offset=(0,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=3,
expected_data=np.array([11, 12, 13, 14, -14, -13, -12]),
expected_offset=(0.5,),
),
dict(
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
input_data=np.array([11, 12, 13, 0]),
input_offset=(1.,),
width=3,
expected_data=np.array([11, 12, 13, 0, -13, -12, -11]),
expected_offset=(1.,),
),
dict(
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
width=2,
expected_data=np.array([11, 12, 13, 14, 14, 13]),
expected_offset=(.5,),
),
)
def test_pad_1d_mirror(self, bc_types, input_data, input_offset, width,
expected_data, expected_offset):
grid = grids.Grid(input_data.shape)
input_data = input_data.astype('float')
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.HomogeneousBoundaryConditions(bc_types)
actual = bc.pad(array, width, axis=0, mode=boundaries.Padding.MIRROR)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([-13, -12, -11, 1, 12, 13, 14]),
input_offset=(-3,),
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([1, 12, 13, 14, 2, -12, -11]),
input_offset=(0,),
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([10, 11, 12, 13, 14]),
input_offset=(-.5,),
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([12, 13, 14, 15, 13]),
input_offset=(.5,),
expected_data=np.array([12, 13, 14, 15]),
expected_offset=(.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([-11, 12, 13, 14, 12]),
input_offset=(-.5,),
expected_data=np.array([12, 13, 14, 12]),
expected_offset=(.5,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
grid_shape=4,
input_data=np.array([12, 13, 14, 12]),
input_offset=(.5,),
expected_data=np.array([12, 13, 14, 12]),
expected_offset=(.5,),
),
# Periodic BC
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
grid_shape=4,
input_data=np.array([-12, 11, 12, 13, 14]),
input_offset=(-1,),
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
grid_shape=4,
input_data=np.array([11, 12, 13, 14, 12]),
input_offset=(0,),
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
)
def test_trim_padding_1d(
self,
grid_shape,
input_data,
input_offset,
bc_types,
expected_data,
expected_offset,
):
grid = grids.Grid((grid_shape,))
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual, _ = bc._trim_padding(array, axis=0)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([-14, -13, -12, 11, 12, 13, 14]),
input_offset=(-3,),
grid_size=4,
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14, 2, -12, -11]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14, 2, -12, -11]),
input_offset=(1,),
grid_size=5,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14, -12, -11]),
input_offset=(.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(.5,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14, 12]),
input_offset=(-.5,),
grid_size=4,
expected_data=np.array([12, 13, 14, 12]),
expected_offset=(.5,),
),
# Periodic BC
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
input_data=np.array([-12, 11, 12, 13, 14]),
input_offset=(-1,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
input_data=np.array([11, 12, 13, 14, 12]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
)
def test_trim_boundary_1d(
self,
input_data,
input_offset,
grid_size,
bc_types,
expected_data,
expected_offset,
):
grid = grids.Grid((grid_size,))
input_data = input_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.trim_boundary(array)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([12, 13, 14]),
input_offset=(1,),
grid_size=4,
expected_data=np.array([1, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(1,),
grid_size=5,
expected_data=np.array([11, 12, 13, 14, 2]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(.5,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
# Periodic BC
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
)
def test_pad_and_impose_bc_1d(
self,
input_data,
input_offset,
grid_size,
bc_types,
expected_data,
expected_offset,
):
grid = grids.Grid((grid_size,))
input_data = input_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.pad_and_impose_bc(array, expected_offset).array
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([0, 12, 13, 14]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([1, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14, 11]),
input_offset=(1,),
grid_size=5,
expected_data=np.array([11, 12, 13, 14, 2]),
expected_offset=(1,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(.5,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
# Periodic BC
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),), (None,)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
grid_size=4,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
)
def test_impose_bc_1d(
self,
input_data,
input_offset,
grid_size,
bc_types,
expected_data,
expected_offset,
):
grid = grids.Grid((grid_size,))
input_data = input_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.impose_bc(array).array
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=-3,
expected_data=np.array([-13, -12, -11, 1, 12, 13, 14]),
expected_offset=(-3,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=0,
expected_data=np.array([1, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=3,
expected_data=np.array([1, 12, 13, 14, 2, -12, -11]),
expected_offset=(0,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=-2,
expected_data=np.array([11, 10, 11, 12, 13, 14]),
expected_offset=(-1.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=2,
expected_data=np.array([11, 12, 13, 14, 16, 15]),
expected_offset=(0.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=-1,
expected_data=np.array([-9, 11, 12, 13, 14]),
expected_offset=(-0.5,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=1,
expected_data=np.array([11, 12, 13, 14, 16]),
expected_offset=(0.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=(2, 1),
expected_data=np.array([-10, -9, 11, 12, 13, 14, 16]),
expected_offset=(-1.5,),
),
)
def test_pad_1d_inhomogeneous(self, bc_types, input_data, input_offset, width,
expected_data, expected_offset):
grid = grids.Grid((4,))
input_data = input_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.pad(array, width, axis=0, mode=boundaries.Padding.MIRROR)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=-3,
expected_data=np.array([1, 1, 1, 1, 12, 13, 14]),
expected_offset=(-3,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=0,
expected_data=np.array([1, 12, 13, 14]),
expected_offset=(0,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((1.0, 2.0),)),
input_data=np.array([1, 12, 13, 14]),
input_offset=(0,),
width=3,
expected_data=np.array([1, 12, 13, 14, 2, 2, 2]),
expected_offset=(0,),
),
# Neumann BC
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=-2,
expected_data=np.array([10, 10, 11, 12, 13, 14]),
expected_offset=(-1.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0.5,),
),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=2,
expected_data=np.array([11, 12, 13, 14, 16, 16]),
expected_offset=(0.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=-1,
expected_data=np.array([-9, 11, 12, 13, 14]),
expected_offset=(-0.5,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=1,
expected_data=np.array([11, 12, 13, 14, 16]),
expected_offset=(0.5,),
),
# Dirichlet / Neumann BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=(2, 1),
expected_data=np.array([-9, -9, 11, 12, 13, 14, 16]),
expected_offset=(-1.5,),
),
)
def test_pad_1d_inhomogeneous_extend(self, bc_types, input_data, input_offset,
width, expected_data, expected_offset):
grid = grids.Grid((4,))
input_data = input_data.astype('float')
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.pad(array, width, 0, mode=boundaries.Padding.EXTEND)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=((2, 1),),
expected_data=np.array([-10, -9, 11, 12, 13, 14, 16]),
expected_offset=(-1.5,),
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.NEUMANN),), ((1.0, 2.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
width=((2, 0),),
expected_data=np.array([-10, -9, 11, 12, 13, 14]),
expected_offset=(-1.5,),
),
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)), ((0.0, 0.0),
(0.0, 0.0))),
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 1),
width=((1, 1), (1, 1)),
expected_data=np.array([
[0, 31, 32, 33, 0, -33],
[0, 11, 12, 13, 0, -13],
[0, 21, 22, 23, 0, -23],
[0, 31, 32, 33, 0, -33],
[0, 11, 12, 13, 0, -13],
]),
expected_offset=(-0.5, 0.),
),
)
def test_pad_all(self, bc_types, input_data, input_offset, width,
expected_data, expected_offset):
grid = grids.Grid(input_data.shape)
input_data = input_data.astype(
'float') # tests fail with integer input_data
expected_data = expected_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
actual = bc.pad_all(array, width, mode=boundaries.Padding.MIRROR)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((10.0, 20.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
expected_offset=(0.0,),
step=1.,
expected_data=np.array([10., 1., 1., 1.])),
dict(
bc_types=(((BCType.NEUMANN, BCType.NEUMANN),), ((10.0, 20.0),)),
input_data=np.array([11, 12, 13, 14]),
input_offset=(0.5,),
expected_offset=(0.0,),
step=0.5,
expected_data=np.array([10., 2., 2., 2.])),
)
def test_neumann_boundary(self, bc_types, input_data, input_offset,
expected_offset, step, expected_data):
grid = grids.Grid((4,), step=step)
input_data = input_data.astype('float')
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])
grid_var = bc.impose_bc(array)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(fd.backward_difference(grid_var, 0), expected)
@parameterized.parameters(
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
width=-2,
axis=0,
expected_data=np.array([
[-21, -22, -23, -24],
[-11, -12, -13, -14],
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
expected_offset=(-1.5, 0.5),
),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
width=2,
axis=1,
expected_data=np.array([
[11, 12, 13, 14, -14, -13],
[21, 22, 23, 24, -24, -23],
[31, 32, 33, 34, -34, -33],
]),
expected_offset=(0.5, 0.5),
),
dict(
input_data=np.array([
[11, 12, 13, 0],
[21, 22, 23, 0],
[31, 32, 33, 0],
]),
input_offset=(0.5, 1), # edge aligned offset
width=-2,
axis=1,
expected_data=np.array([
[-11, 0, 11, 12, 13, 0],
[-21, 0, 21, 22, 23, 0],
[-31, 0, 31, 32, 33, 0],
]),
expected_offset=(0.5, -1),
),
)
def test_pad_dirichlet_cell_center(self, input_data, input_offset, width,
axis, expected_data, expected_offset):
grid = grids.Grid(input_data.shape)
input_data = input_data.astype(
'float') # tests fail with integer input_data
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
actual = bc.pad(array, width, axis, mode=boundaries.Padding.MIRROR)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
width=-2,
axis=0,
),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.0, 0.5),
width=-2,
axis=0,
),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(1.0, 0.5),
width=-2,
axis=0,
))
def test_pad_periodic_raises(self, input_data, input_offset, width, axis):
grid = grids.Grid((4, 4))
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
error_msg = 'the GridArray does not contain all interior grid values.'
with self.assertRaisesRegex(ValueError, error_msg):
_ = bc.pad(array, width, axis, mode=boundaries.Padding.MIRROR)
@parameterized.parameters(
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
width=-1,
axis=0,
),)
def test_pad_neumann_raises(self, input_data, input_offset, width, axis):
grid = grids.Grid((4, 4))
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.neumann_boundary_conditions(grid.ndim)
error_msg = 'the GridArray does not contain all interior grid values.'
with self.assertRaisesRegex(ValueError, error_msg):
_ = bc.pad(array, width, axis, mode=boundaries.Padding.MIRROR)
@parameterized.parameters(
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
width=-2,
axis=0,
),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.0, 0.5),
width=-2,
axis=0,
),
)
def test_pad_dirichlet_raises(self, input_data, input_offset, width, axis):
grid = grids.Grid((4, 4))
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
error_msg = 'the GridArray does not contain all interior grid values.'
with self.assertRaisesRegex(ValueError, error_msg):
_ = bc.pad(array, width, axis, mode=boundaries.Padding.MIRROR)
@parameterized.parameters(
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
values=((1.0, 2.0), (3.0, 4.0)),
width=-2,
axis=0,
expected_data=np.array([
[-19, -20, -21, -22],
[-9, -10, -11, -12],
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
expected_offset=(-1.5, 0.5),
),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
input_offset=(0.5, 0.5),
values=((1.0, 2.0), (3.0, 4.0)),
width=2,
axis=1,
expected_data=np.array([
[11, 12, 13, 14, -6, -5],
[21, 22, 23, 24, -16, -15],
[31, 32, 33, 34, -26, -25],
]),
expected_offset=(0.5, 0.5),
),
dict(
input_data=np.array([
[11, 12, 13, 4],
[21, 22, 23, 4],
[31, 32, 33, 4],
]),
input_offset=(0.5, 1), # edge aligned offset
values=((1.0, 2.0), (3.0, 4.0)),
width=-2,
axis=1,
expected_data=np.array([
[-8, 3, 11, 12, 13, 4],
[-18, 3, 21, 22, 23, 4],
[-28, 3, 31, 32, 33, 4],
]),
expected_offset=(0.5, -1),
),
)
def test_pad_dirichlet_cell_center_inhomogeneous(self, input_data,
input_offset, values, width,
axis, expected_data,
expected_offset):
input_data = input_data.astype(
'float') # tests fail with integer input_data
expected_data = expected_data.astype(
'float') # tests fail with integer input_data
grid = grids.Grid(input_data.shape)
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim, values)
actual = bc.pad(array, width, axis, mode=boundaries.Padding.MIRROR)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=-1,
expected_data=np.array([12, 13, 14]),
expected_offset=(1,),
),
dict(
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=0,
expected_data=np.array([11, 12, 13, 14]),
expected_offset=(0,),
),
dict(
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=1,
expected_data=np.array([11, 12, 13]),
expected_offset=(0,),
),
dict(
input_data=np.array([11, 12, 13, 14]),
input_offset=(0,),
width=2,
expected_data=np.array([11, 12]),
expected_offset=(0,),
),
)
def test_trim_1d(self, input_data, input_offset, width, expected_data,
expected_offset):
grid = grids.Grid(input_data.shape)
array = grids.GridArray(input_data, input_offset, grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
# Note: trim behavior does not depend on bc type
actual = bc._trim(array, width, axis=0)
expected = grids.GridArray(expected_data, expected_offset, grid)
self.assertArrayEqual(actual, expected)
@parameterized.parameters(
dict(
values=((1.0, 2.0),), axis=0, shape=(3,), expected_values=(1.0, 2.0)),
dict(
values=((1.0, 2.0), (3.0, 4.0)),
axis=0,
shape=(3, 4),
expected_values=((1.0, 1.0, 1.0, 1.0), (2.0, 2.0, 2.0, 2.0))),
dict(
values=((1.0, 2.0), (3.0, 4.0)),
axis=1,
shape=(3, 4),
expected_values=((3.0, 3.0, 3.0), (4.0, 4.0, 4.0))),
)
def test_values_constant_boundary(self, values, axis, shape, expected_values):
grid = grids.Grid(shape)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim, values)
actual = bc.values(axis, grid)
self.assertArrayEqual(actual, expected_values)
self.assertIsInstance(actual, tuple)
for x in actual:
self.assertIsInstance(x, jnp.ndarray)
@parameterized.parameters(
dict(axis=0, shape=(3,), expected_values=(0.0, 0.0)),
dict(
axis=0,
shape=(3, 4),
expected_values=((0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0))),
dict(
axis=1,
shape=(3, 4),
expected_values=((0.0, 0.0, 0.0), (0.0, 0.0, 0.0))),
)
def test_values_homogeneous_boundary(self, axis, shape, expected_values):
grid = grids.Grid(shape)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
actual = bc.values(axis, grid)
self.assertArrayEqual(actual, expected_values)
self.assertIsInstance(actual, tuple)
for x in actual:
self.assertIsInstance(x, jnp.ndarray)
@parameterized.parameters(
dict(
input_data=np.array([11, 12, 13, 14]),
offset=(0.0,),
values=((1.0, 2.0),),
expected_data=np.array([1, 12, 13, 14])),
dict(
input_data=np.array([11, 12, 13, 14]),
offset=(1.0,),
values=((1.0, 2.0),),
expected_data=np.array([11, 12, 13, 2])),
dict(
input_data=np.array([11, 12, 13, 14]),
offset=(0.5,),
values=((1.0, 2.0),),
expected_data=np.array([11, 12, 13, 14])),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
offset=(1.0, 0.5),
values=((1.0, 2.0), (3.0, 4.0)),
expected_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[2, 2, 2, 2],
])),
dict(
input_data=np.array([
[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
]),
offset=(0.5, 0.0),
values=((1.0, 2.0), (3.0, 4.0)),
expected_data=np.array([
[3, 12, 13, 14],
[3, 22, 23, 24],
[3, 32, 33, 34],
])),
)
def test_impose_bc_constant_boundary(
self, input_data, offset, values, expected_data):
grid = grids.Grid(input_data.shape)
array = grids.GridArray(input_data, offset, grid)
bc = boundaries.dirichlet_boundary_conditions(grid.ndim, values)
variable = grids.GridVariable(array, bc)
variable = variable.impose_bc()
expected = grids.GridArray(expected_data, offset, grid)
self.assertArrayEqual(variable.array, expected)
def test_has_all_periodic_boundary_conditions(self):
grid = grids.Grid((10, 10))
array = grids.GridArray(np.zeros((10, 10)), (0.5, 0.5), grid)
periodic_bc = boundaries.periodic_boundary_conditions(ndim=2)
nonperiodic_bc = boundaries.periodic_and_neumann_boundary_conditions()
with self.subTest('returns True'):
c = grids.GridVariable(array, periodic_bc)
v = (grids.GridVariable(array, periodic_bc),
grids.GridVariable(array, periodic_bc))
self.assertTrue(boundaries.has_all_periodic_boundary_conditions(c, *v))
with self.subTest('returns False'):
c = grids.GridVariable(array, periodic_bc)
v = (grids.GridVariable(array, periodic_bc),
grids.GridVariable(array, nonperiodic_bc))
self.assertFalse(boundaries.has_all_periodic_boundary_conditions(c, *v))
def test_get_pressure_bc_from_velocity_2d(self):
grid = grids.Grid((10, 10))
u_array = grids.GridArray(jnp.zeros(grid.shape), (1, 0.5), grid)
v_array = grids.GridArray(jnp.zeros(grid.shape), (0.5, 1), grid)
velocity_bc = boundaries.channel_flow_boundary_conditions(ndim=2)
v = (grids.GridVariable(u_array, velocity_bc),
grids.GridVariable(v_array, velocity_bc))
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
self.assertEqual(pressure_bc.types, ((BCType.PERIODIC, BCType.PERIODIC),
(BCType.NEUMANN, BCType.NEUMANN)))
def test_get_pressure_bc_from_velocity_3d(self):
grid = grids.Grid((10, 10, 10))
u_array = grids.GridArray(jnp.zeros(grid.shape), (1, 0.5, 0.5), grid)
v_array = grids.GridArray(jnp.zeros(grid.shape), (0.5, 1, 0.5), grid)
w_array = grids.GridArray(jnp.zeros(grid.shape), (0.5, 0.5, 1), grid)
velocity_bc = boundaries.channel_flow_boundary_conditions(ndim=3)
v = (grids.GridVariable(u_array, velocity_bc),
grids.GridVariable(v_array, velocity_bc),
grids.GridVariable(w_array, velocity_bc))
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
self.assertEqual(pressure_bc.types, ((BCType.PERIODIC, BCType.PERIODIC),
(BCType.NEUMANN, BCType.NEUMANN),
(BCType.PERIODIC, BCType.PERIODIC)))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(pnorgaard) Implement bicgstab for non-symmetric operators
"""Module for functionality related to diffusion."""
from typing import Optional, Tuple
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
def diffuse(c: GridVariable, nu: float) -> GridArray:
"""Returns the rate of change in a concentration `c` due to diffusion."""
return nu * fd.laplacian(c)
def stable_time_step(viscosity: float, grid: grids.Grid) -> float:
"""Calculate a stable time step size for explicit diffusion.
The calculation is based on analysis of central-time-central-space (CTCS)
schemes.
Args:
viscosity: kinematic visosity
grid: a `Grid` object.
Returns:
The prescribed time interval.
"""
if viscosity == 0:
return float('inf')
dx = min(grid.step)
ndim = grid.ndim
return dx ** 2 / (viscosity * 2 ** ndim)
def _subtract_linear_part_dirichlet(
c_data: Array,
grid: grids.Grid,
axis: int,
offset: Tuple[float, float],
bc_values: Tuple[float, float],
) -> Array:
"""Transforms c_data such that c_data satisfies dirichlet boundary.
The function subtracts a linear function from c_data s.t. the returned
array has homogeneous dirichlet boundaries. Note that this assumes c_data has
constant dirichlet boundary values.
Args:
c_data: right-hand-side of diffusion equation.
grid: grid object
axis: axis along which to impose boundary transformation
offset: offset of the right-hand-side
bc_values: boundary values along axis
Returns:
transformed right-hand-side
"""
def _update_rhs_along_axis(arr_1d, linear_part):
arr_1d = arr_1d - linear_part
return arr_1d
lower_value, upper_value = bc_values
y = grid.mesh(offset)[axis][0]
one_d_grid = grids.Grid((grid.shape[axis],), domain=(grid.domain[axis],))
y_boundary = boundaries.dirichlet_boundary_conditions(ndim=1)
y = y_boundary.trim_boundary(grids.GridArray(y, (offset[axis],),
one_d_grid)).data
domain_length = (grid.domain[axis][1] - grid.domain[axis][0])
domain_start = grid.domain[axis][0]
linear_part = lower_value + (upper_value - lower_value) * (
y - domain_start) / domain_length
c_data = jnp.apply_along_axis(
_update_rhs_along_axis, axis, c_data, linear_part)
return c_data
def _rhs_transform(
u: grids.GridArray,
bc: boundaries.BoundaryConditions,
) -> Array:
"""Transforms the RHS of diffusion equation.
In case of constant dirichlet boundary conditions for heat equation
the linear term is subtracted. See diffusion.solve_fast_diag.
Args:
u: a GridArray that solves ∇²x = ∇²u for x.
bc: specifies boundary of u.
Returns:
u' s.t. u = u' + w where u' has 0 dirichlet bc and w is linear.
"""
if not isinstance(bc, boundaries.ConstantBoundaryConditions):
raise NotImplementedError(
f'transformation cannot be done for this {bc}.')
u_data = u.data
for axis in range(u.grid.ndim):
for i, _ in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.DIRICHLET:
bc_values = [0., 0.]
bc_values[i] = bc.bc_values[axis][i]
u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset,
bc_values)
elif bc.types[axis][i] == boundaries.BCType.NEUMANN:
if any(bc.bc_values[axis]):
raise NotImplementedError(
'transformation is not implemented for inhomogeneous Neumann bc.')
return u_data
def solve_cg(v: GridVariableVector,
nu: float,
dt: float,
rtol: float = 1e-6,
atol: float = 1e-6,
maxiter: Optional[int] = None) -> GridVariableVector:
"""Conjugate gradient solve for diffusion."""
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('solve_cg() expects periodic BC')
def solve_component(u: GridVariable) -> GridArray:
"""Solves (1 - ν Δt ∇²) u_{t+1} = u_{tilda} for u_{t+1}."""
def linear_op(u_new: GridArray) -> GridArray:
"""Linear operator for (1 - ν Δt ∇²) u_{t+1}."""
u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u
return u_new.array - dt * nu * fd.laplacian(u_new)
def cg(b: GridArray, x0: GridArray) -> GridArray:
"""Iteratively solves Lx = b. with initial guess x0."""
x, _ = jax.scipy.sparse.linalg.cg(
linear_op, b, x0=x0, tol=rtol, atol=atol, maxiter=maxiter)
return x
return cg(u.array, u.array)
return tuple(grids.GridVariable(solve_component(u), u.bc) for u in v)
def solve_fast_diag(
v: GridVariableVector,
nu: float,
dt: float,
implementation: Optional[str] = None,
) -> GridVariableVector:
"""Solve for diffusion using the fast diagonalization approach."""
# We reuse eigenvectors from the Laplacian and transform the eigenvalues
# because this is better conditioned than directly diagonalizing 1 - ν Δt ∇²
# when ν Δt is small.
def func(x):
dt_nu_x = (dt * nu) * x
return dt_nu_x / (1 - dt_nu_x)
# Compute (1 - ν Δt ∇²)⁻¹ u as u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u, for less
# error when ν Δt is small.
# If dirichlet bc are supplied: only works for dirichlet bc that are linear
# functions on the boundary. Then u = u' + w where u' has 0 dirichlet bc and
# w is linear. Then u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u = u +
# (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function _rhs_transform subtracts
# the linear part s.t. fast_diagonalization solves
# u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u'.
v_diffused = list()
if boundaries.has_all_periodic_boundary_conditions(*v):
circulant = True
else:
circulant = False
# only matmul implementation supports non-circulant matrices
implementation = 'matmul'
for u in v:
laplacians = array_utils.laplacian_matrix_w_boundaries(
u.grid, u.offset, u.bc)
op = fast_diagonalization.transform(
func,
laplacians,
v[0].dtype,
hermitian=True,
circulant=circulant,
implementation=implementation)
u_interior = u.bc.trim_boundary(u.array)
u_interior_transformed = _rhs_transform(u_interior, u.bc)
u_dt_diffused = grids.GridArray(
op(u_interior_transformed), u_interior.offset, u_interior.grid)
u_diffused = u_interior + u_dt_diffused
u_diffused = u.bc.pad_and_impose_bc(u_diffused, offset_to_pad_to=u.offset)
v_diffused.append(u_diffused)
return tuple(v_diffused)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.diffusion."""
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import diffusion
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
class DiffusionTest(test_util.TestCase):
"""Some simple sanity tests for diffusion on constant fields."""
def diffusion_setup(self, bc, offset):
nu = 1.0
dt = 0.5
rs = np.random.RandomState(0)
b = rs.randn(4, 4).astype(np.float32)
grid = grids.Grid((4, 4), domain=((0, 4), (0, 4))) # has step = 1.0
b = bc.impose_bc(grids.GridArray(b, offset, grid))
x = diffusion.solve_fast_diag((b,), nu, dt)[0]
# laplacian is defined only on the interior
x_interior = grids.GridVariable(bc.trim_boundary(x.array), bc)
x = x_interior.array - nu * dt * fd.laplacian(x_interior)
return x, b
def test_explicit_diffusion(self):
nu = 1.
shape = (101, 101, 101)
offset = (0.5, 0.5, 0.5)
step = (1., 1., 1.)
grid = grids.Grid(shape, step)
c = grids.GridVariable(
array=grids.GridArray(jnp.ones(shape), offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
diffused = diffusion.diffuse(c, nu)
expected = grids.GridArray(jnp.zeros_like(diffused.data), offset, grid)
self.assertAllClose(expected, diffused)
@parameterized.parameters(
dict(solve=diffusion.solve_cg, atol=1e-6),
dict(solve=diffusion.solve_fast_diag, atol=1e-6),
)
def test_implicit_diffusion(self, solve, atol):
nu = 1.
dt = 0.1
shape = (100, 100)
grid = grids.Grid(shape, step=1)
periodic_bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = (
grids.GridVariable(
grids.GridArray(jnp.ones(shape), (1, 0.5), grid), periodic_bc),
grids.GridVariable(
grids.GridArray(jnp.ones(shape), (0.5, 1), grid), periodic_bc),
)
actual = solve(v, nu, dt)
expected = v
self.assertAllClose(expected[0], actual[0], atol=atol)
self.assertAllClose(expected[1], actual[1], atol=atol)
@parameterized.parameters(((1.0, 0.5), 0.0), ((1.0, 1.0), 0.0),
((1.0, 0.0), 0.0), ((1.0, 0.0), 0.0),
((1.0, 0.5), 1.0), ((1.0, 1.0), 1.0),
((1.0, 0.0), 1.0))
def test_diffusion_2d_periodic_and_dirichlet(self, offset, value_lo):
bc = boundaries.periodic_and_dirichlet_boundary_conditions((value_lo, 0.0))
x, b = self.diffusion_setup(bc, offset)
self.assertAllClose(
x.data,
bc.trim_boundary(b).data,
atol=1e-5)
self.assertArrayEqual(x.grid, b.grid)
@parameterized.parameters(((1.0, 0.5),), ((0.5, 0.5),))
def test_diffusion_2d_periodic_and_neumann(self, offset):
bc = boundaries.periodic_and_neumann_boundary_conditions()
x, b = self.diffusion_setup(bc, offset)
self.assertAllClose(
x.data,
bc.trim_boundary(b).data,
atol=1e-5)
self.assertArrayEqual(x.grid, b.grid)
@parameterized.parameters(((0.5, 0.5),))
def test_diffusion_2d_fully_neumann(self, offset):
bc = boundaries.neumann_boundary_conditions(2)
x, b = self.diffusion_setup(bc, offset)
self.assertAllClose(
x.data,
bc.trim_boundary(b).data,
atol=1e-5)
self.assertArrayEqual(x.grid, b.grid)
@parameterized.parameters(((1.0, 0.5),), ((1.0, 1.0),), ((1.0, 0.0),))
def test_diffusion_2d_fully_dirichlet(self, offset):
bc = boundaries.dirichlet_boundary_conditions(2)
x, b = self.diffusion_setup(bc, offset)
self.assertAllClose(
x.data,
bc.trim_boundary(b).data,
atol=1e-5)
self.assertArrayEqual(x.grid, b.grid)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples of defining equations."""
import functools
from typing import Callable, Optional
import jax
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import diffusion
from jax_cfd.base import grids
from jax_cfd.base import pressure
from jax_cfd.base import time_stepping
import tree_math
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
ConvectFn = Callable[[GridVariableVector], GridArrayVector]
DiffuseFn = Callable[[GridVariable, float], GridArray]
ForcingFn = Callable[[GridVariableVector], GridArrayVector]
def sum_fields(*args):
# return jax.tree_util.tree_map(lambda *a: sum(a), *args)
return jax.tree_util.tree_map(lambda *a: sum(a), *args)
def stable_time_step(
max_velocity: float,
max_courant_number: float,
viscosity: float,
grid: grids.Grid,
implicit_diffusion: bool = False,
) -> float:
"""Calculate a stable time step for Navier-Stokes."""
dt = advection.stable_time_step(max_velocity, max_courant_number, grid)
if not implicit_diffusion:
diffusion_dt = diffusion.stable_time_step(viscosity, grid)
if diffusion_dt < dt:
raise ValueError(f'stable time step for diffusion is smaller than '
f'the chosen timestep: {diffusion_dt} vs {dt}')
return dt
def dynamic_time_step(v: GridVariableVector,
max_courant_number: float,
viscosity: float,
grid: grids.Grid,
implicit_diffusion: bool = False) -> float:
"""Pick a dynamic time-step for Navier-Stokes based on stable advection."""
v_max = jnp.sqrt(jnp.max(sum(u.data ** 2 for u in v)))
return stable_time_step( # pytype: disable=wrong-arg-types # jax-types
v_max, max_courant_number, viscosity, grid, implicit_diffusion)
def _wrap_term_as_vector(fun, *, name):
return tree_math.unwrap(jax.named_call(fun, name=name), vector_argnums=0)
def navier_stokes_explicit_terms(
density: float,
viscosity: float,
dt: float,
grid: grids.Grid,
convect: Optional[ConvectFn] = None,
diffuse: DiffuseFn = diffusion.diffuse,
forcing: Optional[ForcingFn] = None,
) -> Callable[[GridVariableVector], GridVariableVector]:
"""Returns a function that performs a time step of Navier Stokes."""
del grid # unused
if convect is None:
def convect(v): # pylint: disable=function-redefined
return tuple(
advection.advect_van_leer_using_limiters(u, v, dt) for u in v)
def diffuse_velocity(v, *args):
return tuple(diffuse(u, *args) for u in v)
convection = _wrap_term_as_vector(convect, name='convection')
diffusion_ = _wrap_term_as_vector(diffuse_velocity, name='diffusion')
if forcing is not None:
forcing = _wrap_term_as_vector(forcing, name='forcing')
@tree_math.wrap
@functools.partial(jax.named_call, name='navier_stokes_momentum')
def _explicit_terms(v):
dv_dt = convection(v)
if viscosity is not None:
dv_dt += diffusion_(v, viscosity / density)
if forcing is not None:
dv_dt += forcing(v) / density
return dv_dt
def explicit_terms_with_same_bcs(v):
dv_dt = _explicit_terms(v)
return tuple(grids.GridVariable(a, u.bc) for a, u in zip(dv_dt, v))
return explicit_terms_with_same_bcs
# TODO(shoyer): rename this to explicit_diffusion_navier_stokes
def semi_implicit_navier_stokes(
density: float,
viscosity: float,
dt: float,
grid: grids.Grid,
convect: Optional[ConvectFn] = None,
diffuse: DiffuseFn = diffusion.diffuse,
pressure_solve: Callable = pressure.solve_fast_diag,
forcing: Optional[ForcingFn] = None,
time_stepper: Callable = time_stepping.forward_euler,
) -> Callable[[GridVariableVector], GridVariableVector]:
"""Returns a function that performs a time step of Navier Stokes."""
explicit_terms = navier_stokes_explicit_terms(
density=density,
viscosity=viscosity,
dt=dt,
grid=grid,
convect=convect,
diffuse=diffuse,
forcing=forcing)
pressure_projection = jax.named_call(pressure.projection, name='pressure')
# TODO(jamieas): Consider a scheme where pressure calculations and
# advection/diffusion are staggered in time.
ode = time_stepping.ExplicitNavierStokesODE(
explicit_terms,
lambda v: pressure_projection(v, pressure_solve)
)
step_fn = time_stepper(ode, dt)
return step_fn
def implicit_diffusion_navier_stokes(
density: float,
viscosity: float,
dt: float,
grid: grids.Grid,
convect: Optional[ConvectFn] = None,
diffusion_solve: Callable = diffusion.solve_fast_diag,
pressure_solve: Callable = pressure.solve_fast_diag,
forcing: Optional[ForcingFn] = None,
) -> Callable[[GridVariableVector], GridVariableVector]:
"""Returns a function that performs a time step of Navier Stokes."""
del grid # unused
if convect is None:
def convect(v): # pylint: disable=function-redefined
return tuple(
advection.advect_van_leer_using_limiters(u, v, dt) for u in v)
convect = jax.named_call(convect, name='convection')
pressure_projection = jax.named_call(pressure.projection, name='pressure')
diffusion_solve = jax.named_call(diffusion_solve, name='diffusion')
# TODO(shoyer): refactor to support optional higher-order time integators
@jax.named_call
def navier_stokes_step(v: GridVariableVector) -> GridVariableVector:
"""Computes state at time `t + dt` using first order time integration."""
convection = convect(v)
accelerations = [convection]
if forcing is not None:
# TODO(shoyer): include time in state?
f = forcing(v)
accelerations.append(tuple(f / density for f in f))
dvdt = sum_fields(*accelerations)
# Update v by taking a time step
v = tuple(
grids.GridVariable(u.array + dudt * dt, u.bc)
for u, dudt in zip(v, dvdt))
# Pressure projection to incompressible velocity field
v = pressure_projection(v, pressure_solve)
# Solve for implicit diffusion
v = diffusion_solve(v, viscosity, dt)
return v
return navier_stokes_step
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.equations."""
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import boundaries
from jax_cfd.base import diffusion
from jax_cfd.base import equations
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import pressure
from jax_cfd.base import test_util
from jax_cfd.base import time_stepping
import numpy as np
def zero_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns an all-zero periodic velocity fields."""
return tuple(
grids.GridVariable(grids.GridArray(jnp.zeros(grid.shape), o, grid),
boundaries.periodic_boundary_conditions(grid.ndim))
for o in grid.cell_faces)
def sinusoidal_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns a divergence-free velocity flow on `grid`."""
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
vs = tuple(
jnp.sin(2. * np.pi * g / s) for g, s in zip(grid.mesh(), mesh_size))
return tuple(
grids.GridVariable(grids.GridArray(v, o, grid),
boundaries.periodic_boundary_conditions(grid.ndim))
for v, o in zip(vs[1:] + vs[:1], grid.cell_faces))
def gaussian_force_field(grid):
"""Returns a 'Gaussian-shaped' force field in the 'x' direction."""
mesh = grid.mesh()
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
offsets = grid.cell_faces
v = [grids.GridArray(
jnp.exp(-sum([jnp.square(x / s - .5)
for x, s in zip(mesh, mesh_size)]) * 100.),
offsets[0], grid)]
for j in range(1, grid.ndim):
v.append(grids.GridArray(jnp.zeros(grid.shape), offsets[j], grid))
return tuple(v)
def gaussian_forcing(v: grids.GridVariableVector) -> grids.GridArrayVector:
"""Returns Gaussian field forcing."""
grid = grids.consistent_grid(*v)
return gaussian_force_field(grid)
def momentum(v: grids.GridVariableVector, density: float):
"""Returns the momentum due to velocity field `v`."""
grid = grids.consistent_grid(*v)
return jnp.array([u.data for u in v]).sum() * density * jnp.array(
grid.step).prod()
def _convect_upwind(v: grids.GridVariableVector) -> grids.GridArrayVector:
return tuple(advection.advect_upwind(u, v) for u in v)
class SemiImplicitNavierStokesTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='semi_implicit_sinusoidal_velocity_base',
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=None,
pressure_solve=pressure.solve_cg,
time_stepper=time_stepping.forward_euler,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=2e-3),
dict(testcase_name='semi_implicit_gaussian_force_upwind',
velocity=zero_velocity_field,
forcing=gaussian_forcing,
shape=(40, 40, 40),
step=(1., 1., 1.),
density=1.,
viscosity=None,
convect=_convect_upwind,
pressure_solve=pressure.solve_cg,
time_stepper=time_stepping.midpoint_rk2,
dt=1e-3,
time_steps=100,
divergence_atol=1e-4,
momentum_atol=2e-4),
dict(testcase_name='semi_implicit_sinusoidal_velocity_fast_diag',
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=advection.convect_linear,
pressure_solve=pressure.solve_fast_diag,
time_stepper=time_stepping.classic_rk4,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
)
def test_divergence_and_momentum(
self, velocity, forcing, shape, step, density, viscosity, convect,
pressure_solve, time_stepper, dt, time_steps, divergence_atol,
momentum_atol,
):
grid = grids.Grid(shape, step)
navier_stokes = equations.semi_implicit_navier_stokes(
density,
viscosity,
dt,
grid,
convect=convect,
pressure_solve=pressure_solve,
forcing=forcing,
time_stepper=time_stepper,
)
v_initial = velocity(grid)
v_final = funcutils.repeated(navier_stokes, time_steps)(v_initial)
divergence = fd.divergence(v_final)
self.assertLess(jnp.max(divergence.data), divergence_atol)
initial_momentum = momentum(v_initial, density)
final_momentum = momentum(v_final, density)
if forcing is not None:
expected_change = (
jnp.array([f.data for f in forcing(v_initial)]).sum() *
jnp.array(grid.step).prod() * dt * time_steps)
else:
expected_change = 0
self.assertAllClose(
initial_momentum + expected_change, final_momentum, atol=momentum_atol)
class ImplicitDiffusionNavierStokesTest(test_util.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='implicit_sinusoidal_velocity_base',
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=None,
diffusion_solve=diffusion.solve_cg,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=3e-3),
dict(
testcase_name='implicit_gaussian_force_upwind',
velocity=zero_velocity_field,
forcing=gaussian_forcing,
shape=(40, 40, 40),
step=(1., 1., 1.),
density=1.,
viscosity=1e-4,
convect=_convect_upwind,
diffusion_solve=diffusion.solve_cg,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=100,
divergence_atol=1e-4,
momentum_atol=9e-4),
dict(
testcase_name='implicit_sinusoidal_velocity_fast_diag',
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=advection.convect_linear,
diffusion_solve=diffusion.solve_fast_diag,
pressure_solve=pressure.solve_fast_diag,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
)
def test_divergence_and_momentum(
self, velocity, forcing, shape, step, density, viscosity, convect,
diffusion_solve, pressure_solve, dt, time_steps, divergence_atol,
momentum_atol,
):
grid = grids.Grid(shape, step)
navier_stokes = equations.implicit_diffusion_navier_stokes(
density,
viscosity,
dt,
grid,
convect=convect,
diffusion_solve=diffusion_solve,
pressure_solve=pressure_solve,
forcing=forcing)
v_initial = velocity(grid)
v_final = funcutils.repeated(navier_stokes, time_steps)(v_initial)
divergence = fd.divergence(v_final)
self.assertLess(jnp.max(divergence.data), divergence_atol)
initial_momentum = momentum(v_initial, density)
final_momentum = momentum(v_final, density)
if forcing is not None:
expected_change = (
jnp.array([f.data for f in forcing(v_initial)]).sum() *
jnp.array(grid.step).prod() * dt * time_steps)
else:
expected_change = 0
self.assertAllClose(
initial_momentum + expected_change, final_momentum, atol=momentum_atol)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast diagonalization method for inverting linear operators."""
import functools
from typing import Callable, Optional, Sequence, Union
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np
Array = Union[np.ndarray, jax.Array]
def transform(
func: Callable[[Array], Array],
operators: Sequence[np.ndarray],
dtype: np.dtype,
*,
hermitian: bool = False,
circulant: bool = False,
implementation: Optional[str] = None,
precision: lax.Precision = lax.Precision.HIGHEST,
) -> Callable[[Array], Array]:
"""Apply a linear operator written as a sum of operators on each axis.
Such linear operators are *separable*, and can be written as a sum of tensor
products, e.g., `operators = [A, B]` corresponds to the linear operator
A ⊗ I + I ⊗ B, where the tensor product ⊗ indicates a separation between
operators applied along the first and second axis.
This function computes matrix-valued functions of such linear operators via
the "fast diagonalization method" [1]:
F(A ⊗ I + I ⊗ B)
= (X(A) ⊗ X(B)) F(Λ(A) ⊗ I + I ⊗ Λ(B)) (X(A)^{-1} ⊗ X(B)^{-1})
where X(A) denotes the matrix of eigenvectors of A and Λ(A) denotes the
(diagonal) matrix of eigenvalues. The function `F` is easy to compute in
this basis, because matrix Λ(A) ⊗ I + I ⊗ Λ(B) is diagonal.
The current implementation directly diagonalizes dense matrices for each
linear operator, which limits it's applicability to grids with less than
1e3-1e4 elements per side (~1 second to several minutes of setup time).
Example: The Laplacian operator can be written as a sum of 1D Laplacian
operators along each axis, i.e., as a sum of 1D convolutions along each axis.
This can be seen mathematically (∇² = ∂²/∂x² + ∂²/∂y² + ∂²/∂z²) or by
decomposing the 2D kernel:
[0 1 0] [ 1]
[1 -4 1] = [1 -2 1] + [-2]
[0 1 0] [ 1]
Args:
func: NumPy function applied in the diagonal basis that is passed the
N-dimensional array of eigenvalues (one dimension for each linear
operator).
operators: forward linear operators as matrices, applied along each axis.
Each of these matrices is diagonalized.
dtype: dtype of the right-hand-side.
hermitian: whether or not all linear operator are Hermitian (i.e., symmetric
in the real valued case).
circulant: whether or not all linear operators are circulant.
implementation: how to implement fast diagonalization. Default uses 'rfft'
for grid size larger than 1024 and 'matmul' otherwise:
- 'matmul': scales like O(N**(d+1)) for d N-dimensional operators, but
makes good use of matmul hardware. Requires hermitian=True.
- 'fft': scales like O(N**d * log(N)) for d N-dimensional operators.
Requires circulant=True.
- 'rfft': use the RFFT instead of the FFT. This is a little faster than
'fft' but also has slightly larger error. It currently requires an even
sized last axis and circulant=True.
precision: numerical precision for matrix multplication. Only relevant on
TPUs with implementation='matmul'.
Returns:
A function that computes the transformation of the indicated operator.
References:
[1] Lynch, R. E., Rice, J. R. & Thomas, D. H. Direct solution of partial
difference equations by tensor product methods. Numer. Math. 6, 185–199
(1964). https://paperpile.com/app/p/b7fdea4e-b2f7-0ada-b056-a282325c3ecf
"""
if any(op.ndim != 2 or op.shape[0] != op.shape[1] for op in operators):
raise ValueError('operators are not all square matrices. Shapes are '
+ ', '.join(str(op.shape) for op in operators))
if implementation is None:
if all(device.platform == 'tpu' for device in jax.local_devices()):
size = max(op.shape[0] for op in operators)
implementation = 'rfft' if size > 1024 else 'matmul'
else:
implementation = 'rfft'
if implementation == 'rfft' and operators[-1].shape[0] % 2:
implementation = 'matmul'
if implementation == 'matmul':
if not hermitian:
raise ValueError('non-hermitian operators not yet supported with '
'implementation="matmul"')
return _hermitian_matmul_transform(func, operators, dtype, precision)
elif implementation == 'fft':
if not circulant:
raise ValueError('non-circulant operators not yet supported with '
'implementation="fft"')
return _circulant_fft_transform(func, operators, dtype)
elif implementation == 'rfft':
if not circulant:
raise ValueError('non-circulant operators not yet supported with '
'implementation="rfft"')
return _circulant_rfft_transform(func, operators, dtype)
else:
raise ValueError(f'invalid implementation: {implementation}')
def _hermitian_matmul_transform(
func: Callable[[Array], Array],
operators: Sequence[np.ndarray],
dtype: np.dtype,
precision: lax.Precision = lax.Precision.HIGHEST,
) -> Callable[[Array], Array]:
"""Fast diagonalization by matrix multiplication along each axis."""
eigenvalues, eigenvectors = zip(*map(np.linalg.eigh, operators))
# Example: if eigenvalues=[a, b, c], then:
# summed_eigenvalues[i, j, k] == a[i] + b[j] + c[k]
# for all i, j, k.
summed_eigenvalues = functools.reduce(np.add.outer, eigenvalues)
diagonals = jnp.asarray(func(summed_eigenvalues), dtype)
eigenvectors = [jnp.asarray(vector, dtype) for vector in eigenvectors]
shape = summed_eigenvalues.shape
if diagonals.shape != shape:
raise ValueError('output shape from func() does not match input shape: '
f'{diagonals.shape} vs {shape}')
def apply(rhs: Array) -> Array:
if rhs.shape != shape:
raise ValueError(f'rhs.shape={rhs.shape} does not match shape={shape}')
if rhs.dtype != dtype:
raise ValueError(f'rhs.dtype={rhs.dtype} does not match dtype={dtype}')
# Use tensordot so we have more control over the underlying XLA operations.
out = rhs
for vectors in eigenvectors:
out = jnp.tensordot(out, vectors, axes=(0, 0), precision=precision)
out *= diagonals
for vectors in eigenvectors:
out = jnp.tensordot(out, vectors, axes=(0, 1), precision=precision)
return out
return apply
def _circulant_fft_transform(
func: Callable[[Array], Array],
operators: Sequence[np.ndarray],
dtype: np.dtype,
) -> Callable[[Array], Array]:
"""Fast diagonalization by Fast Fourier Transform."""
# https://en.wikipedia.org/wiki/Circulant_matrix#Eigenvectors_and_eigenvalues
eigenvalues = [np.fft.fft(op[:, 0]) for op in operators]
summed_eigenvalues = functools.reduce(np.add.outer, eigenvalues)
diagonals = jnp.asarray(func(summed_eigenvalues))
shape = tuple(op.shape[0] for op in operators)
if diagonals.shape != shape:
raise ValueError('output shape from func() does not match input shape: '
f'{diagonals.shape} vs {shape}')
def apply(rhs: Array) -> Array:
if rhs.shape != shape:
raise ValueError(f'rhs.shape={rhs.shape} does not match shape={shape}')
return jnp.fft.ifftn(diagonals * jnp.fft.fftn(rhs)).astype(dtype)
return apply
def _circulant_rfft_transform(
func: Callable[[Array], Array],
operators: Sequence[np.ndarray],
dtype: np.dtype,
) -> Callable[[Array], Array]:
"""Fast diagonalization by real-valued Fast Fourier Transform."""
# https://en.wikipedia.org/wiki/Circulant_matrix#Eigenvectors_and_eigenvalues
if operators[-1].shape[0] % 2:
raise ValueError('implementation="rfft" currently requires an even size '
'for the last axis')
# Use `rfft()` only on the last operator so the shape of `diagonals` matches
# the shape of the output from `rfftn()` without any extra wrangling.
eigenvalues = ([np.fft.fft(op[:, 0]) for op in operators[:-1]]
+ [np.fft.rfft(operators[-1][:, 0])])
summed_eigenvalues = functools.reduce(np.add.outer, eigenvalues)
diagonals = jnp.asarray(func(summed_eigenvalues))
if diagonals.shape != summed_eigenvalues.shape:
raise ValueError('output shape from func() does not match input shape: '
f'{diagonals.shape} vs {summed_eigenvalues.shape}')
def apply(rhs: Array) -> Array:
if rhs.dtype != dtype:
raise ValueError(f'rhs.dtype={rhs.dtype} does not match dtype={dtype}')
return jnp.fft.irfftn(diagonals * jnp.fft.rfftn(rhs)).astype(dtype)
return apply
def pseudoinverse(
operators: Sequence[np.ndarray],
dtype: np.dtype,
*,
hermitian: bool = False,
circulant: bool = False,
implementation: Optional[str] = None,
precision: lax.Precision = lax.Precision.HIGHEST,
cutoff: Optional[float] = None,
) -> Callable[[Array], Array]:
"""Invert a linear operator written as a sum of operators on each axis.
Args:
operators: forward linear operators as matrices, applied along each axis.
Each of these matrices is diagonalized.
dtype: dtype of the right-hand-side.
hermitian: whether or not all linear operator are Hermitian (i.e., symmetric
in the real valued case).
circulant: whether or not all linear operators are circulant.
implementation: how to implement fast diagonalization.
precision: numerical precision for matrix multplication. Only relevant on
TPUs.
cutoff: eigenvalues with absolute value smaller than this number are
discarded rather than being inverted. By default, uses 10 times floating
point epsilon.
Returns:
A function that computes the pseudo-inverse of the indicated operator.
"""
if cutoff is None:
cutoff = 10 * jnp.finfo(dtype).eps
def func(v):
with np.errstate(divide='ignore', invalid='ignore'):
return np.where(abs(v) > cutoff, 1 / v, 0)
return transform(func, operators, dtype, hermitian=hermitian,
circulant=circulant, implementation=implementation,
precision=precision)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment