# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for functionality related to advection."""
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import interpolation
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = interpolation.InterpolationFn
# TODO(dkochkov) Consider testing if we need operator splitting methods.
def _advect_aligned(cs: GridVariableVector, v: GridVariableVector) -> GridArray:
"""Computes fluxes and the associated advection for aligned `cs` and `v`.
The values `cs` should consist of a single quantity `c` that has been
interpolated to the offset of the components of `v`. The components of `v` and
`cs` should be located at the faces of a single (possibly offset) grid cell.
We compute the advection as the divergence of the flux on this control volume.
The boundary condition on the flux is inherited from the scalar quantity `c`.
A typical example in three dimensions would have
```
cs[0].offset == v[0].offset == (1., .5, .5)
cs[1].offset == v[1].offset == (.5, 1., .5)
cs[2].offset == v[2].offset == (.5, .5, 1.)
```
In this case, the returned advection term would have offset `(.5, .5, .5)`.
Args:
cs: a sequence of `GridVariable`s; a single value `c` that has been
interpolated so that it is aligned with each component of `v`.
v: a sequence of `GridVariable`s describing a velocity field. Should be defined
on the same Grid as `cs`.
Returns:
A `GridArray` containing the time derivative of `c` due to advection by
`v`.
Raises:
ValueError: `cs` and `v` have different numbers of components.
AlignmentError: if the components of `cs` are not aligned with those of `v`.
"""
# TODO(jamieas): add more sophisticated alignment checks, ensuring that the
# values are located on the faces of a control volume.
if len(cs) != len(v):
raise ValueError('`cs` and `v` must have the same length; '
f'got {len(cs)} vs. {len(v)}.')
flux = tuple(c.array * u.array for c, u in zip(cs, v))
bcs = tuple(
boundaries.get_advection_flux_bc_from_velocity_and_scalar(v[i], cs[i], i)
for i in range(len(v)))
flux = tuple(bc.impose_bc(f) for f, bc in zip(flux, bcs))
return -fd.divergence(flux)
def advect_general(
c: GridVariable,
v: GridVariableVector,
u_interpolation_fn: InterpolationFn,
c_interpolation_fn: InterpolationFn,
dt: Optional[float] = None) -> GridArray:
"""Computes advection of a scalar quantity `c` by the velocity field `v`.
This function uses the following procedure:
1. Interpolate each component of `v` to the corresponding face of the
control volume centered on `c`.
2. Interpolate `c` to the same control volume faces.
3. Compute the flux `cu` using the aligned values.
4. Set the boundary condition on the flux, which is inherited from `c`.
5. Return the negative divergence of the flux.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
u_interpolation_fn: method for interpolating velocity field `v`.
c_interpolation_fn: method for interpolating scalar field `c`.
dt: unused time-step.
Returns:
The time derivative of `c` due to advection by `v`.
"""
if not boundaries.has_all_periodic_boundary_conditions(c):
raise NotImplementedError(
'Non-periodic boundary conditions are not implemented.')
target_offsets = grids.control_volume_offsets(c)
aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt)
for u, target_offset in zip(v, target_offsets))
aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
for target_offset in target_offsets)
return _advect_aligned(aligned_c, aligned_v)
def advect_linear(c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None) -> GridArray:
"""Computes advection using linear interpolations."""
return advect_general(c, v, interpolation.linear, interpolation.linear, dt)
def advect_upwind(c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None) -> GridArray:
"""Computes advection using first-order upwind interpolation on `c`."""
return advect_general(c, v, interpolation.linear, interpolation.upwind, dt)
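# A minimal sketch of how the advect_* functions above are typically used:
# each returns a time derivative (a `GridArray`), so an explicit scheme pairs
# it with a forward-Euler update and re-imposes the boundary condition, just
# like the `_euler_step` helper in the tests. The helper name below is
# illustrative, not part of the library API.
def _euler_advection_step_example(c: GridVariable,
                                  v: GridVariableVector,
                                  dt: float) -> GridVariable:
  """Advances `c` by one explicit Euler step of upwind advection."""
  dc_dt = advect_upwind(c, v, dt)  # time derivative of c due to advection
  return c.bc.impose_bc(c.array + dt * dc_dt)  # reattach boundary conditions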
def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]:
"""Returns interpolated components of `v` needed for convection.
Args:
v: a velocity field.
Returns:
A d-tuple of d-tuples of `GridVariable`s `aligned_v`, where `d = len(v)`.
The entry `aligned_v[i][j]` is the component `v[i]` after interpolation to
the appropriate face of the control volume centered around `v[j]`.
"""
grid = grids.consistent_grid(*v)
offsets = tuple(grids.control_volume_offsets(u) for u in v)
aligned_v = tuple(
tuple(interpolation.linear(v[i], offsets[i][j])
for j in range(grid.ndim))
for i in range(grid.ndim))
return aligned_v
def _velocities_to_flux(
aligned_v: Tuple[GridVariableVector]) -> Tuple[GridVariableVector]:
"""Computes the fluxes across the control volume faces for a velocity field.
This is the flux associated with the nonlinear term `vv` for velocity `v`.
The boundary condition on the flux is inherited from `v`.
Args:
aligned_v: a d-tuple of d-tuples of `GridVariable`s such that the entry
`aligned_v[i][j]` is the component `v[i]` after interpolation to
the appropriate face of the control volume centered around `v[j]`. This is
the output of `_align_velocities`.
Returns:
A tuple of tuples `flux` of `GridVariable`s with the same structure as
`aligned_v`. The entry `flux[i][j]` is `aligned_v[i][j] * aligned_v[j][i]`.
"""
ndim = len(aligned_v)
flux = [tuple() for _ in range(ndim)]
for i in range(ndim):
for j in range(ndim):
if i <= j:
bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar(
aligned_v[j][i], aligned_v[i][j], j)
flux[i] += (bc.impose_bc(aligned_v[i][j].array *
aligned_v[j][i].array),)
else:
flux[i] += (flux[j][i],)
return tuple(flux)
def convect_linear(v: GridVariableVector) -> GridArrayVector:
"""Computes convection/self-advection of the velocity field `v`.
This function is conceptually equivalent to
```
def convect_linear(v, grid):
return tuple(advect_linear(u, v, grid) for u in v)
```
However, there are several optimizations to avoid duplicate operations.
Args:
v: a velocity field.
Returns:
A tuple containing the time derivative of each component of `v` due to
convection.
"""
# TODO(jamieas): consider a more efficient vectorization of this function.
# TODO(jamieas): incorporate variable density.
aligned_v = _align_velocities(v)
fluxes = _velocities_to_flux(aligned_v)
return tuple(-fd.divergence(flux) for flux in fluxes)
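# As noted in the docstring above, `convect_linear` should agree with advecting
# each velocity component separately. The helper below is an illustrative
# consistency check (mirroring `test_convection_vs_advection` in the tests),
# not part of the library API; it assumes a periodic velocity field `v`.
def _convection_matches_advection_example(v: GridVariableVector) -> bool:
  """Returns True if convect_linear matches per-component advect_linear."""
  dv_dt = convect_linear(v)
  return all(
      bool(jnp.allclose(advect_linear(u, v).data, du.data))
      for u, du in zip(v, dv_dt))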
def advect_van_leer(
c: GridVariable,
v: GridVariableVector,
dt: float,
mode: str = boundaries.Padding.MIRROR,
) -> GridArray:
"""Computes advection of a scalar quantity `c` by the velocity field `v`.
Implements the van Leer flux-limiting scheme, which uses a second-order
accurate approximation of fluxes in smooth regions of the solution. This
scheme is total variation diminishing (TVD). In regions with high gradients
the flux limiter transforms the scheme into a first-order method. See [1] for
a reference. This function uses the following procedure:
1. Shifts c to offset < 1 if necessary.
2. Scalar c now has a well-defined right-hand (upwind) value.
3. Computes the upwind flux for each direction.
4. Computes the van Leer flux limiter:
a. Uses the shifted c to interpolate each component of `v` to the
right-hand (upwind) face of the control volume centered on `c`.
b. Computes the ratio of successive gradients. In the non-periodic case,
the value outside the boundary is not defined, so `mode` is used to
interpolate past the boundary.
c. Computes the flux limiter function.
d. Computes the higher-order flux correction.
5. Combines fluxes and assigns the flux boundary condition.
6. Computes the negative divergence of fluxes.
7. Shifts the computed values back to the original offset of c.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
dt: time step for which this scheme is TVD and second order accurate
in time.
mode: For non-periodic BC, specifies extrapolation of values beyond the
boundary, which is used by nonlinear interpolation.
Returns:
The time derivative of `c` due to advection by `v`.
#### References
[1]: MIT 18.336 spring 2009 Finite Volume Methods Lecture 19.
go/mit-18.336-finite_volume_methods-19
[2]:
www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf
"""
# TODO(dkochkov) reimplement this using apply_limiter method.
c_left_var = c
# if the offset is 1., shift by 1 to offset 0.
# otherwise c_right is not defined.
for ax in range(c.grid.ndim):
# int(c.offset[ax] % 1 - c.offset[ax]) = -1 if c.offset[ax] is 1 else
# int(c.offset[ax] % 1 - c.offset[ax]) = 0.
# i.e. this shifts the 1 aligned data to 0 offset, the rest is unchanged.
c_left_var = c.bc.impose_bc(
c_left_var.shift(int(c.offset[ax] % 1 - c.offset[ax]), axis=ax))
offsets = grids.control_volume_offsets(c_left_var)
# if c offset is 0, aligned_v is at 0.5.
# if c offset is at .5, aligned_v is at 1.
aligned_v = tuple(interpolation.linear(u, offset)
for u, offset in zip(v, offsets))
flux = []
# Assign flux boundary condition
flux_bc = [
boundaries.get_advection_flux_bc_from_velocity_and_scalar(u, c, direction)
for direction, u in enumerate(v)
]
# first, compute upwind flux.
for axis, u in enumerate(aligned_v):
c_center = c_left_var.data
# by shifting c_left + 1, c_right is well-defined.
c_right = c_left_var.shift(+1, axis=axis).data
upwind_flux = grids.applied(jnp.where)(
u.array > 0, u.array * c_center, u.array * c_right)
flux.append(upwind_flux)
# next, compute van_leer correction.
for axis, (u, h) in enumerate(zip(aligned_v, c.grid.step)):
u = u.bc.shift(u.array, int(u.offset[axis] % 1 - u.offset[axis]), axis=axis)
# c is put to offset .5 or 1.
c_center_arr = c.shift(int(1 - c.offset[axis]), axis=axis)
# if c offset is 1, u offset is .5.
# if c offset is .5, u offset is 0.
# u_i is always on the left of c_center_var_i
c_center = c_center_arr.data
# shift -1 are well defined now
# shift +1 is not well defined for c offset 1 because then c(wall + 1) is
# not defined.
# However, the flux that uses c(wall + 1) offset gets overridden anyways
# when flux boundary condition is overridden.
# Thus, any mode can be used here.
c_right = c.bc.shift(c_center_arr, +1, axis=axis, mode=mode).data
c_left = c.bc.shift(c_center_arr, -1, axis=axis).data
# shift -2 is tricky:
# It is well defined if c is periodic.
# Else, c(-1) or c(-1.5) are not defined.
# Then, mode is used to interpolate the values.
c_left_left = c.bc.shift(
c_center_arr, -2, axis, mode=mode).data
numerator_positive = c_left - c_left_left
numerator_negative = c_right - c_center
numerator = grids.applied(jnp.where)(u > 0, numerator_positive,
numerator_negative)
denominator = grids.GridArray(c_center - c_left, u.offset, u.grid)
# We want to calculate denominator / (abs(denominator) + abs(numerator))
# To make it differentiable, it needs to be done in stages.
# ensures that there is no division by 0
phi_van_leer_denominator_avoid_nans = grids.applied(jnp.where)(
abs(denominator) > 0, (abs(denominator) + abs(numerator)), 1.)
phi_van_leer_denominator_inv = denominator / phi_van_leer_denominator_avoid_nans
phi_van_leer = numerator * (grids.applied(jnp.sign)(denominator) +
grids.applied(jnp.sign)
(numerator)) * phi_van_leer_denominator_inv
abs_velocity = abs(u)
courant_numbers = (dt / h) * abs_velocity
pre_factor = 0.5 * (1 - courant_numbers) * abs_velocity
flux_correction = pre_factor * phi_van_leer
# Shift back onto original offset.
flux_correction = flux_bc[axis].shift(
flux_correction, int(offsets[axis][axis] - u.offset[axis]), axis=axis)
flux[axis] += flux_correction
flux = tuple(flux_bc[axis].impose_bc(f) for axis, f in enumerate(flux))
advection = -fd.divergence(flux)
# shift the variable back onto the original offset
for ax in range(c.grid.ndim):
advection = c.bc.shift(
advection, -int(c.offset[ax] % 1 - c.offset[ax]), axis=ax)
return advection
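# For reference, the textbook van Leer limiter used by TVD schemes of this kind
# is phi(r) = (r + |r|) / (1 + |r|), where r is the ratio of successive
# gradients computed above. The standalone sketch below is illustrative only
# and is not the library's `interpolation.van_leer_limiter`.
def _van_leer_limiter_reference(r: jnp.ndarray) -> jnp.ndarray:
  """Textbook van Leer limiter: 0 for r <= 0, approaching 2 as r -> infinity."""
  return (r + jnp.abs(r)) / (1.0 + jnp.abs(r))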
def advect_step_semilagrangian(
c: GridVariable,
v: GridVariableVector,
dt: float
) -> GridVariable:
"""Semi-Lagrangian advection of a scalar quantity.
Note that unlike the other advection functions, this function returns values
at the next time-step, not the time derivative.
Args:
c: the quantity to be transported.
v: a velocity field. Should be defined on the same Grid as c.
dt: desired time-step.
Returns:
Advected quantity at the next time-step -- *not* the time derivative.
"""
# Reference: "Learning to control PDEs with Differentiable Physics"
# https://openreview.net/pdf?id=HyeSin4FPB (see Appendix A)
grid = grids.consistent_grid(c, *v)
# TODO(shoyer) Enable lower domains != 0 for this function.
# Hint: indices = [
# -o + (x - l) * n / (u - l)
# for (l, u), o, x, n in zip(grid.domain, c.offset, coords, grid.shape)
# ]
if not all(d[0] == 0 for d in grid.domain):
raise ValueError(
f'Grid domains currently must start at zero. Found {grid.domain}')
coords = [x - dt * interpolation.linear(u, c.offset).data
for x, u in zip(grid.mesh(c.offset), v)]
indices = [x / s - o for s, o, x in zip(grid.step, c.offset, coords)]
if not boundaries.has_all_periodic_boundary_conditions(c):
raise NotImplementedError('non-periodic BCs not yet supported')
c_advected = grids.applied(jax.scipy.ndimage.map_coordinates)(
c.array, indices, order=1, mode='wrap')
return GridVariable(c_advected, c.bc)
# TODO(dkochkov) Implement advect_with_flux_limiter method.
# TODO(dkochkov) Consider moving `advect_van_leer` to test based on performance.
def advect_van_leer_using_limiters(
c: GridVariable,
v: GridVariableVector,
dt: float
) -> GridArray:
"""Implements Van-Leer advection by applying TVD limiter to Lax-Wendroff."""
c_interpolation_fn = interpolation.apply_tvd_limiter(
interpolation.lax_wendroff, limiter=interpolation.van_leer_limiter)
return advect_general(c, v, interpolation.linear, c_interpolation_fn, dt)
def stable_time_step(max_velocity: float,
max_courant_number: float,
grid: grids.Grid) -> float:
"""Calculate a stable time step size for explicit advection.
The calculation is based on the CFL condition for advection.
Args:
max_velocity: maximum velocity.
max_courant_number: the Courant number used to choose the time step. Smaller
numbers will lead to more stable simulations. Typically this should be in
the range [0.5, 1).
grid: a `Grid` object.
Returns:
The prescribed time interval.
"""
dx = min(grid.step)
return max_courant_number * dx / max_velocity
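# A minimal usage sketch: on a 256 x 256 unit-square grid (dx = 1/256) with a
# peak speed of 1 and a Courant number of 0.5, this gives
# dt = 0.5 * (1 / 256) / 1.0, i.e. roughly 2e-3. The grid construction below
# is illustrative.
def _stable_advection_time_step_example() -> float:
  grid = grids.Grid((256, 256), domain=((0.0, 1.0), (0.0, 1.0)))
  return stable_time_step(max_velocity=1.0, max_courant_number=0.5, grid=grid)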
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.advection."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import boundaries
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
def _gaussian_concentration(grid):
offset = tuple(-int(jnp.ceil(s / 2.)) for s in grid.shape)
return grids.GridArray(
jnp.exp(-sum(jnp.square(m) * 30. for m in grid.mesh(offset=offset))),
(0.5,) * len(grid.shape), grid)
def _square_concentration(grid):
select_square = lambda x: jnp.where(jnp.logical_and(x > 0.4, x < 0.6), 1., 0.)
return grids.GridArray(
jnp.array([select_square(m) for m in grid.mesh()]).prod(0),
(0.5,) * len(grid.shape), grid)
def _unit_velocity(grid, velocity_sign=1.):
ndim = grid.ndim
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(velocity_sign * jnp.ones(grid.shape) if ax == 0
else jnp.zeros(grid.shape), tuple(offset), grid)
for ax, offset in enumerate(offsets))
def _cos_velocity(grid):
ndim = grid.ndim
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
mesh = grid.mesh()
v = tuple(grids.GridArray(jnp.cos(mesh[i] * 2. * np.pi), tuple(offset), grid)
for i, offset in enumerate(offsets))
return v
def _velocity_implicit(grid, offset, u, t):
"""Returns solution of a Burgers equation at time `t`."""
x = grid.mesh((offset,))[0]
return grids.GridArray(jnp.sin(x - u * t), (offset,), grid)
def _total_variation(array, motion_axis):
next_values = array.shift(1, motion_axis)
variation = jnp.sum(jnp.abs(next_values.data - array.data))
return variation
def _euler_step(advection_method):
def step(c, v, dt):
c_new = c.array + dt * advection_method(c, v, dt)
return c.bc.impose_bc(c_new)
return step
class AdvectionTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='linear_1D',
shape=(101,),
method=_euler_step(advection.advect_linear),
num_steps=1000,
cfl_number=0.01,
atol=5e-2),
dict(testcase_name='linear_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_linear),
num_steps=1000,
cfl_number=0.01,
atol=5e-2),
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind),
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='upwind_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_upwind),
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_van_leer),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_using_limiters_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer_using_limiters),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='van_leer_using_limiters_3D',
shape=(101, 5, 5),
method=_euler_step(advection.advect_van_leer_using_limiters),
num_steps=100,
cfl_number=0.5,
atol=2e-2),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian,
num_steps=100,
cfl_number=0.5,
atol=7e-2),
dict(testcase_name='semilagrangian_3D',
shape=(101, 5, 5),
method=advection.advect_step_semilagrangian,
num_steps=100,
cfl_number=0.5,
atol=7e-2),
)
def test_advection_analytical(
self, shape, method, num_steps, cfl_number, atol, v_sign=1):
"""Tests advection of a Gaussian concentration on a periodic grid."""
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = cfl_number * min(step)
advect = functools.partial(method, v=v, dt=dt)
evolve = jax.jit(funcutils.repeated(advect, num_steps))
ct = evolve(c)
expected_shift = int(round(-cfl_number * num_steps * v_sign))
expected = c.shift(expected_shift, axis=0).data
self.assertAllClose(expected, ct.data, atol=atol)
@parameterized.named_parameters(
dict(
testcase_name='dirichlet_1d_100', shape=(100,), atol=0.001,
offset=.5),
dict(
testcase_name='dirichlet_1d_200',
shape=(200,),
atol=0.00025,
offset=.5),
dict(
testcase_name='dirichlet_1d_400',
shape=(400,),
atol=0.00007,
offset=.5),
dict(
testcase_name='dirichlet_1d_100_cell_edge_0',
shape=(100,),
atol=0.002,
offset=0.),
dict(
testcase_name='dirichlet_1d_200_cell_edge_0',
shape=(200,),
atol=0.0005,
offset=0.),
dict(
testcase_name='dirichlet_1d_400_cell_edge_0',
shape=(400,),
atol=0.000125,
offset=0.),
dict(
testcase_name='dirichlet_1d_100_cell_edge_1',
shape=(100,),
atol=0.002,
offset=1.),
dict(
testcase_name='dirichlet_1d_200_cell_edge_1',
shape=(200,),
atol=0.0005,
offset=1.),
dict(
testcase_name='dirichlet_1d_400_cell_edge_1',
shape=(400,),
atol=0.000125,
offset=1.),
)
def test_burgers_analytical_dirichlet_convergence(
self,
shape,
atol,
offset,
):
num_steps = 1000
cfl_number = 0.01
step = 2 * jnp.pi / 1000
grid = grids.Grid(shape, domain=([0., 2 * jnp.pi],))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
v = (bc.impose_bc(_velocity_implicit(grid, offset, 0, 0)),)
dt = cfl_number * step
def _advect(v):
dv_dt = advection.advect_van_leer(c=v[0], v=v, dt=dt) / 2
return (bc.impose_bc(v[0].array + dt * dv_dt),)
evolve = jax.jit(funcutils.repeated(_advect, num_steps))
ct = evolve(v)
expected = bc.impose_bc(
_velocity_implicit(grid, offset, ct[0].data, dt * num_steps)).data
self.assertAllClose(expected, ct[0].data, atol=atol)
@parameterized.named_parameters(
dict(testcase_name='linear_1D',
shape=(101,),
method=_euler_step(advection.advect_linear),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='linear_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_linear),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='upwind_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_upwind),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_using_limiters_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer_using_limiters),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_using_limiters_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer_using_limiters),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian,
atol=1e-2,
rtol=1e-2),
dict(testcase_name='semilagrangian_3D',
shape=(101, 101, 101),
method=advection.advect_step_semilagrangian,
atol=1e-2,
rtol=1e-2),
)
def test_advection_gradients(
self, shape, method, atol, rtol, cfl_number=0.5, v_sign=1,
):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = cfl_number * min(step)
advect = jax.remat(functools.partial(method, v=v, dt=dt))
evolve = jax.jit(funcutils.repeated(advect, steps=10))
def objective(c):
return 0.5 * jnp.sum(evolve(c).data ** 2)
gradient = jax.jit(jax.grad(objective))(c)
self.assertAllClose(c, gradient, atol=atol, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
dict(testcase_name='van_leer_1D_negative_v',
shape=(101,),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2,
v_sign=-1.),
dict(testcase_name='van_leer_3D',
shape=(101, 101, 101),
method=_euler_step(advection.advect_van_leer),
atol=1e-2,
rtol=1e-2),
)
def test_advection_gradients_division_by_zero(
self, shape, method, atol, rtol, cfl_number=0.5, v_sign=1,
):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_unit_velocity(grid)[0], bc)
dt = cfl_number * min(step)
advect = jax.remat(functools.partial(method, v=v, dt=dt))
evolve = jax.jit(funcutils.repeated(advect, steps=10))
def objective(c):
return 0.5 * jnp.sum(evolve(c).data ** 2)
gradient = jax.jit(jax.grad(objective))(c)
self.assertAllClose(c, gradient, atol=atol, rtol=rtol)
@parameterized.named_parameters(
dict(testcase_name='_linear_1D',
shape=(101,),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear),
dict(testcase_name='_linear_3D',
shape=(101, 101, 101),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear)
)
def test_convection_vs_advection(
self, shape, advection_method, convection_method,
):
"""Exercises self-advection, check equality with advection on components."""
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _cos_velocity(grid))
self_advected = convection_method(v)
for u, du in zip(v, self_advected):
advected_component = advection_method(u, v)
self.assertAllClose(advected_component, du)
@parameterized.named_parameters(
dict(testcase_name='upwind_1D',
shape=(101,),
method=_euler_step(advection.advect_upwind)),
dict(testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer)),
dict(testcase_name='semilagrangian_1D',
shape=(101,),
method=advection.advect_step_semilagrangian),
)
def test_tvd_property(self, shape, method):
atol = 1e-6
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid))
c = grids.GridVariable(_square_concentration(grid), bc)
dt = min(step) / 100.
num_steps = 300
ct = c
advect = jax.jit(functools.partial(method, v=v, dt=dt))
initial_total_variation = _total_variation(c, 0) + atol
for _ in range(num_steps):
ct = advect(ct)
current_total_variation = _total_variation(ct, 0)
self.assertLessEqual(current_total_variation, initial_total_variation)
@parameterized.named_parameters(
dict(
testcase_name='van_leer_1D',
shape=(101,),
method=_euler_step(advection.advect_van_leer)),
)
def test_mass_conservation(self, shape, method):
offset = 0.5
cfl_number = 0.1
dt = cfl_number / shape[0]
num_steps = 1000
grid = grids.Grid(shape, domain=([-1., 1.],))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),))
def u(grid, offset):
x = grid.mesh((offset,))[0]
return grids.GridArray(-jnp.sin(jnp.pi * x), (offset,), grid)
def c0(grid, offset):
x = grid.mesh((offset,))[0]
return grids.GridArray(x, (offset,), grid)
v = (bc.impose_bc(u(grid, 1.)),)
c = c_bc.impose_bc(c0(grid, offset))
ct = c
advect = jax.jit(functools.partial(method, v=v, dt=dt))
initial_mass = np.sum(c.data)
for _ in range(num_steps):
ct = advect(ct)
current_total_mass = np.sum(ct.data)
self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)
@parameterized.named_parameters(
dict(testcase_name='van_leers_equivalence_1d',
shape=(101,), v_sign=1.),
dict(testcase_name='van_leers_equivalence_3d',
shape=(101, 101, 101), v_sign=1.),
dict(testcase_name='van_leers_equivalence_1d_negative_v',
shape=(101,), v_sign=-1.),
dict(testcase_name='van_leers_equivalence_3d_negative_v',
shape=(101, 101, 101), v_sign=-1.),
)
def test_van_leer_same_as_van_leer_using_limiters(self, shape, v_sign):
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _unit_velocity(grid, v_sign))
c = grids.GridVariable(_gaussian_concentration(grid), bc)
dt = min(step) / 100.
num_steps = 100
advect_vl = jax.jit(
functools.partial(_euler_step(advection.advect_van_leer), v=v, dt=dt))
advect_vl_using_limiter = jax.jit(
functools.partial(
_euler_step(advection.advect_van_leer_using_limiters), v=v, dt=dt))
c_vl = c
c_vl_using_limiter = c
for _ in range(num_steps):
c_vl = advect_vl(c_vl)
c_vl_using_limiter = advect_vl_using_limiter(c_vl_using_limiter)
self.assertAllClose(c_vl, c_vl_using_limiter, atol=1e-5)
if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility methods for manipulating array-like objects."""
from typing import Any, Callable, List, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np
import scipy.linalg
# There is currently no good way to indicate a jax "pytree" with arrays at its
# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any
Array = Union[np.ndarray, jax.Array]
def _normalize_axis(axis: int, ndim: int) -> int:
"""Validates and returns positive `axis` value."""
if not -ndim <= axis < ndim:
raise ValueError(f'invalid axis {axis} for ndim {ndim}')
if axis < 0:
axis += ndim
return axis
def slice_along_axis(
inputs: PyTree,
axis: int,
idx: Union[slice, int],
expect_same_dims: bool = True
) -> PyTree:
"""Returns slice of `inputs` defined by `idx` along axis `axis`.
Args:
inputs: array or a tuple of arrays to slice.
axis: axis along which to slice the `inputs`.
idx: index or slice along axis `axis` that is returned.
expect_same_dims: whether all arrays should have same number of dimensions.
Returns:
Slice of `inputs` defined by `idx` along axis `axis`.
"""
arrays, tree_def = jax.tree_util.tree_flatten(inputs)
ndims = set(a.ndim for a in arrays)
if expect_same_dims and len(ndims) != 1:
raise ValueError('arrays in `inputs` expected to have same ndims, but have '
f'{ndims}. To allow this, pass expect_same_dims=False')
sliced = []
for array in arrays:
ndim = array.ndim
slc = tuple(idx if j == _normalize_axis(axis, ndim) else slice(None)
for j in range(ndim))
sliced.append(array[slc])
return jax.tree_util.tree_unflatten(tree_def, sliced)
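# A small usage sketch (values are illustrative): slicing a pytree of arrays
# along a shared axis preserves the pytree structure.
def _slice_along_axis_example():
  tree = {'a': jnp.zeros((4, 8)), 'b': jnp.ones((4, 8))}
  # Keep the first two entries along axis 0 of every leaf.
  return slice_along_axis(tree, axis=0, idx=slice(0, 2))
  # => {'a': array of shape (2, 8), 'b': array of shape (2, 8)}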
def split_along_axis(
inputs: PyTree,
split_idx: int,
axis: int,
expect_same_dims: bool = True
) -> Tuple[PyTree, PyTree]:
"""Returns a tuple of slices of `inputs` split along `axis` at `split_idx`.
Args:
inputs: pytree of arrays to split.
split_idx: index along `axis` where the second split starts.
axis: axis along which to split the `inputs`.
expect_same_dims: whether all arrays should have same number of dimensions.
Returns:
Tuple of slices of `inputs` split along `axis` at `split_idx`.
"""
first_slice = slice_along_axis(
inputs, axis, slice(0, split_idx), expect_same_dims)
second_slice = slice_along_axis(
inputs, axis, slice(split_idx, None), expect_same_dims)
return first_slice, second_slice
def split_axis(
inputs: PyTree,
axis: int,
keep_dims: bool = False
) -> Tuple[PyTree, ...]:
"""Splits the arrays in `inputs` along `axis`.
Args:
inputs: pytree to be split.
axis: axis along which to split the `inputs`.
keep_dims: whether to keep `axis` dimension.
Returns:
Tuple of pytrees that correspond to slices of `inputs` along `axis`. The
`axis` dimension is removed if `keep_dims` is set to False.
Raises:
ValueError: if arrays in `inputs` don't have unique size along `axis`.
"""
arrays, tree_def = jax.tree_util.tree_flatten(inputs)
axis_shapes = set(a.shape[axis] for a in arrays)
if len(axis_shapes) != 1:
raise ValueError(f'Arrays must have equal sized axis but got {axis_shapes}')
axis_shape, = axis_shapes
splits = [jnp.split(a, axis_shape, axis=axis) for a in arrays]
if not keep_dims:
splits = jax.tree_util.tree_map(lambda a: jnp.squeeze(a, axis), splits)
splits = zip(*splits)
return tuple(jax.tree_util.tree_unflatten(tree_def, leaves) for leaves in splits)
def concat_along_axis(pytrees, axis):
"""Concatenates `pytrees` along `axis`."""
concat_leaves_fn = lambda *args: jnp.concatenate(args, axis)
return jax.tree_util.tree_map(concat_leaves_fn, *pytrees)
def block_reduce(
array: Array,
block_size: Tuple[int, ...],
reduction_fn: Callable[[Array], Array]
) -> Array:
"""Breaks `array` into `block_size` pieces and applies `f` to each.
This function is equivalent to `scikit-image.measure.block_reduce`:
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.block_reduce
Args:
array: an array.
block_size: the size of the blocks on which the reduction is performed.
Must evenly divide `array.shape`.
reduction_fn: a reduction function that will be applied to each block of
size `block_size`.
Returns:
The result of applying `reduction_fn` to each block of size `block_size`.
"""
new_shape = []
for b, s in zip(block_size, array.shape):
multiple, residual = divmod(s, b)
if residual != 0:
raise ValueError('`block_size` must divide `array.shape`; '
f'got {block_size}, {array.shape}.')
new_shape += [multiple, b]
multiple_axis_reduction_fn = reduction_fn
for j in reversed(range(array.ndim)):
multiple_axis_reduction_fn = jax.vmap(multiple_axis_reduction_fn, j)
return multiple_axis_reduction_fn(array.reshape(new_shape))
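# A small usage sketch: coarsening a field by a factor of 2 in each direction
# by averaging 2 x 2 blocks. The input array is illustrative.
def _block_reduce_example():
  fine = jnp.arange(16.0).reshape(4, 4)
  coarse = block_reduce(fine, block_size=(2, 2), reduction_fn=jnp.mean)
  return coarse  # shape (2, 2); each entry is the mean of one 2 x 2 block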
def laplacian_matrix(size: int, step: float) -> np.ndarray:
"""Create 1D Laplacian operator matrix, with periodic BC."""
column = np.zeros(size)
column[0] = -2 / step**2
column[1] = column[-1] = 1 / step**2
return scipy.linalg.circulant(column)
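# A worked example (this matches the expectation in the accompanying unit
# tests): for size=4 and step=0.5 each row holds the periodic second-difference
# stencil [1, -2, 1] / step**2, wrapped cyclically:
#
#   laplacian_matrix(4, step=0.5)
#   => 4.0 * [[-2,  1,  0,  1],
#             [ 1, -2,  1,  0],
#             [ 0,  1, -2,  1],
#             [ 1,  0,  1, -2]]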
def _laplacian_boundary_dirichlet_cell_centered(laplacians: List[Array],
grid: grids.Grid, axis: int,
side: str) -> None:
"""Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.
laplacians[i] contains a 3 point stencil matrix L that approximates
d^2/dx_i^2.
For detailed documentation on laplacians input type see
array_utils.laplacian_matrix.
The default return of array_utils.laplacian_matrix makes a matrix for
periodic boundary. For dirichlet boundary, the correct equation is
L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
laplacian_boundary_dirichlet restricts the matrix L to
interior points only.
This function assumes RHS has cell-centered offset.
Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose dirichlet bc.
side: lower or upper side to assign boundary to.
Returns:
None. The list `laplacians` is updated in place.
"""
# This function assumes homogeneous boundary, in which case if the offset
# is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
# 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
if side == 'lower':
laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis]**2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis]**2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return
def _laplacian_boundary_neumann_cell_centered(laplacians: List[Array],
grid: grids.Grid, axis: int,
side: str) -> None:
"""Converts 1d laplacian matrix to satisfy neumann homogeneous bc.
This function assumes the RHS will have a cell-centered offset.
Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
code.
Args:
laplacians: list of 1d laplacians
grid: grid object
axis: axis along which to impose neumann bc.
side: which boundary side to convert to neumann homogeneous bc.
Returns:
None. The list `laplacians` is updated in place.
"""
if side == 'lower':
laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis]**2
else:
laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis]**2
# deletes corner dependencies on the "looped-around" part.
# this should be done irrespective of which side, since one boundary cannot
# be periodic while the other is.
laplacians[axis][0, -1] = 0.0
laplacians[axis][-1, 0] = 0.0
return
def laplacian_matrix_w_boundaries(
grid: grids.Grid,
offset: Tuple[float, ...],
bc: boundaries.BoundaryConditions,
) -> List[Array]:
"""Returns 1d laplacians that satisfy boundary conditions bc on grid.
Given a grid, offset and boundary conditions, returns a list of 1d laplacians
(one along each axis).
Currently, only homogeneous or periodic boundary conditions are supported.
Args:
grid: The grid used to construct the laplacian.
offset: The offset of the variable on which laplacian acts.
bc: the boundary condition of the variable on which the laplacian acts.
Returns:
A list of 1d laplacians.
"""
if not isinstance(bc, boundaries.ConstantBoundaryConditions):
raise NotImplementedError(
f'Explicit laplacians are not implemented for {bc}.')
laplacians = list(map(laplacian_matrix, grid.shape, grid.step))
for axis in range(grid.ndim):
if np.isclose(offset[axis], 0.5):
for i, side in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.NEUMANN:
_laplacian_boundary_neumann_cell_centered(
laplacians, grid, axis, side)
elif bc.types[axis][i] == boundaries.BCType.DIRICHLET:
_laplacian_boundary_dirichlet_cell_centered(
laplacians, grid, axis, side)
if np.isclose(offset[axis] % 1, 0.):
if bc.types[axis][0] == boundaries.BCType.DIRICHLET and bc.types[
axis][1] == boundaries.BCType.DIRICHLET:
# This function assumes homogeneous boundary and acts on the interior.
# Thus, the laplacian can be cut off past the edge.
# The interior grid has one fewer grid cell than the actual grid, so
# the size of the laplacian should be reduced.
laplacians[axis] = laplacians[axis][:-1, :-1]
elif boundaries.BCType.NEUMANN in bc.types[axis]:
raise NotImplementedError(
'edge-aligned Neumann boundaries are not implemented.')
return laplacians
def unstack(array, axis):
"""Returns a tuple of slices of `array` along axis `axis`."""
squeeze_fn = lambda x: jnp.squeeze(x, axis=axis)
return tuple(squeeze_fn(x) for x in jnp.split(array, array.shape[axis], axis))
def gram_schmidt_qr(
matrix: Array,
precision: jax.lax.Precision = jax.lax.Precision.HIGHEST
) -> Tuple[Array, Array]:
"""Computes QR decomposition using gramm-schmidt orthogonalization.
This algorithm is suitable for tall matrices with very few columns. This
method is more memory efficient compared to `jnp.linalg.qr`, but is less
numerically stable, especially for matrices with many columns.
Args:
matrix: 2D array representing the matrix to be decomposed into orthogonal
and upper triangular.
precision: numerical precision for matrix multiplication. Only relevant on
TPUs.
Returns:
tuple of matrix Q whose columns are orthonormal and R that is upper
triangular.
"""
def orthogonalize(vector, others):
"""Returns the orthogonal component of `vector` with respect to `others`."""
if not others:
return vector / jnp.linalg.norm(vector)
orthogonalize_step = lambda c, x: tuple([c - jnp.dot(c, x) * x, None])
vector, _ = jax.lax.scan(orthogonalize_step, vector, jnp.stack(others))
return vector / jnp.linalg.norm(vector)
num_columns = matrix.shape[1]
columns = unstack(matrix, axis=1)
q_columns = []
r_rows = []
for vec_index, column in enumerate(columns):
next_q_column = orthogonalize(column, q_columns)
r_rows.append(jnp.asarray([
jnp.dot(columns[i], next_q_column) if i >= vec_index else 0.
for i in range(num_columns)]))
q_columns.append(next_q_column)
q = jnp.stack(q_columns, axis=1)
r = jnp.stack(r_rows)
# flip the signs of the columns of q to make the diagonal entries of r positive.
d = jnp.diag(jnp.sign(jnp.diagonal(r)))
q = jnp.matmul(q, d, precision=precision)
r = jnp.matmul(d, r, precision=precision)
return q, r
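# A small usage sketch: for a tall, skinny matrix the factors should
# reconstruct the input, with Q orthonormal and R upper triangular. The input
# matrix below is illustrative.
def _gram_schmidt_qr_example():
  matrix = jnp.asarray(np.random.RandomState(0).randn(16, 2))
  q, r = gram_schmidt_qr(matrix)
  reconstruction_error = jnp.max(jnp.abs(q @ r - matrix))
  return q, r, reconstruction_error  # error should be near float precision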
def interp1d( # pytype: disable=annotation-type-mismatch # jnp-type
x: Array,
y: Array,
axis: int = -1,
fill_value: Union[str, Array] = jnp.nan,
assume_sorted: bool = True,
) -> Callable[[Array], jax.Array]:
"""Build an interpolation function to approximate `y = f(x)`.
x and y are arrays of values used to approximate some function f: y = f(x).
This returns a function that uses linear interpolation to approximate f
evaluated at new points.
```
x = jnp.linspace(0, 10)
y = jnp.sin(x)
f = interp1d(x, y)
x_new = 1.23
f(x_new)
==> Approximately sin(1.23).
x_new = ... # Shape (4, 5) array
f(x_new)
==> Shape (4, 5) array, approximating sin(x_new).
```
Args:
x: Length N 1-D array of values.
y: Shape (..., N, ...) array of values corresponding to f(x).
axis: Specifies the axis of y along which to interpolate. Interpolation
defaults to the last axis of y.
fill_value: Scalar array or string. If array, this value will be used to
fill in for requested points outside of the data range. If not provided,
then the default is NaN. If "extrapolate", then linear extrapolation is
used. If "constant_extrapolate", then the function is extended as a
constant.
assume_sorted: Whether to assume x is sorted. If True, x must be
monotonically increasing. If False, this function sorts x and reorders
y appropriately.
Returns:
Callable mapping array x_new to values y_new, where
y_new.shape = y.shape[:axis] + x_new.shape + y.shape[axis + 1:]
"""
allowed_fill_value_strs = {'constant_extrapolate', 'extrapolate'}
if isinstance(fill_value, str):
if fill_value not in allowed_fill_value_strs:
raise ValueError(
f'`fill_value` "{fill_value}" not in {allowed_fill_value_strs}')
else:
fill_value = np.asarray(fill_value)
if fill_value.ndim > 0:
raise ValueError(f'Only scalar `fill_value` supported. '
f'Found shape: {fill_value.shape}')
x = jnp.asarray(x)
if x.ndim != 1:
raise ValueError(f'Expected `x` to be 1D. Found shape {x.shape}')
if x.size < 2:
raise ValueError(f'`x` must have at least 2 entries. Found shape {x.shape}')
n_x = x.shape[0]
if not assume_sorted:
ind = jnp.argsort(x)
x = x[ind]
y = jnp.take(y, ind, axis=axis)
y = jnp.asarray(y)
if y.ndim < 1:
raise ValueError(f'Expected `y` to have rank >= 1. Found shape {y.shape}')
if x.size != y.shape[axis]:
raise ValueError(
f'x and y arrays must be equal in length along interpolation axis. '
f'Found x.shape={x.shape} and y.shape={y.shape} and axis={axis}')
axis = _normalize_axis(axis, ndim=y.ndim)
def interp_func(x_new: jax.Array) -> jax.Array:
"""Implementation of the interpolation function."""
x_new = jnp.asarray(x_new)
# We will flatten x_new, then interpolate, then reshape the output.
x_new_shape = x_new.shape
x_new = jnp.ravel(x_new)
# This construction of indices ensures that below_idx < above_idx, even at
# x_new = x[i] exactly for some i.
x_new_clipped = jnp.clip(x_new, np.min(x), np.max(x))
above_idx = jnp.minimum(n_x - 1,
jnp.searchsorted(x, x_new_clipped, side='right'))
below_idx = jnp.maximum(0, above_idx - 1)
def expand(array):
"""Add singletons to rightmost dims of `array` so it bcasts with y."""
array = jnp.asarray(array)
return jnp.reshape(array, array.shape + (1,) * (y.ndim - axis - 1))
x_above = jnp.take(x, above_idx)
x_below = jnp.take(x, below_idx)
y_above = jnp.take(y, above_idx, axis=axis)
y_below = jnp.take(y, below_idx, axis=axis)
slope = (y_above - y_below) / expand(x_above - x_below)
if isinstance(fill_value, str) and fill_value == 'extrapolate':
delta_x = expand(x_new - x_below)
y_new = y_below + delta_x * slope
elif isinstance(fill_value, str) and fill_value == 'constant_extrapolate':
delta_x = expand(x_new_clipped - x_below)
y_new = y_below + delta_x * slope
else: # Else fill_value is an Array.
delta_x = expand(x_new - x_below)
fill_value_ = expand(fill_value)
y_new = y_below + delta_x * slope
y_new = jnp.where(
(delta_x < 0) | (delta_x > expand(x_above - x_below)),
fill_value_, y_new)
return jnp.reshape(
y_new, y_new.shape[:axis] + x_new_shape + y_new.shape[axis + 1:])
return interp_func
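# A short usage sketch of the fill behavior (values are illustrative): with the
# default fill value, points outside the data range evaluate to NaN, while
# fill_value='extrapolate' continues the interpolant linearly.
def _interp1d_fill_example():
  x = jnp.array([0.0, 1.0, 2.0])
  y = jnp.array([0.0, 10.0, 20.0])
  f_nan = interp1d(x, y)
  f_extrapolate = interp1d(x, y, fill_value='extrapolate')
  return f_nan(3.0), f_extrapolate(3.0)  # (nan, 30.0)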
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of array utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
import scipy.interpolate as spi
import skimage.measure as skm
BCType = boundaries.BCType
class ArrayUtilsTest(test_util.TestCase):
@parameterized.parameters(
dict(array=np.random.RandomState(1234).randn(3, 6, 9),
block_size=(3, 3, 3),
f=jnp.mean),
dict(array=np.random.RandomState(1234).randn(12, 24, 36),
block_size=(6, 6, 6),
f=jnp.max),
dict(array=np.random.RandomState(1234).randn(12, 24, 36),
block_size=(3, 4, 6),
f=jnp.min),
)
def test_block_reduce(self, array, block_size, f):
"""Test `block_reduce` is equivalent to `skimage.measure.block_reduce`."""
expected_output = skm.block_reduce(array, block_size, f)
actual_output = array_utils.block_reduce(array, block_size, f)
self.assertAllClose(expected_output, actual_output, atol=1e-6)
def test_laplacian_matrix(self):
actual = array_utils.laplacian_matrix(4, step=0.5)
expected = 4.0 * np.array(
[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]])
self.assertAllClose(expected, actual)
@parameterized.parameters(
# Periodic BC
dict(
offset=(0,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
dict(
offset=(0.5,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
dict(
offset=(1.,),
bc_types=((BCType.PERIODIC, BCType.PERIODIC),),
expected=[[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1,
-2]]),
# Dirichlet BC
dict(
offset=(0,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]]),
dict(
offset=(0.5,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-3, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-3]]),
dict(
offset=(1.,),
bc_types=((BCType.DIRICHLET, BCType.DIRICHLET),),
expected=[[-2, 1, 0], [1, -2, 1], [0, 1, -2]]),
# Neumann BC
dict(
offset=(0.5,),
bc_types=((BCType.NEUMANN, BCType.NEUMANN),),
expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-1]]),
# Neumann-Dirichlet BC
dict(
offset=(0.5,),
bc_types=((BCType.NEUMANN, BCType.DIRICHLET),),
expected=[[-1, 1, 0, 0], [1, -2, 1, 0], [0, 1, -2, 1], [0, 0, 1,
-3]]),
)
def test_laplacian_matrix_w_boundaries(self, offset, bc_types, expected):
grid = grids.Grid((4,), step=(.5,))
bc = boundaries.HomogeneousBoundaryConditions(bc_types)
actual = array_utils.laplacian_matrix_w_boundaries(grid, offset, bc)
actual = np.squeeze(actual)
expected = 4.0 * np.array(expected)
self.assertAllClose(expected, actual)
@parameterized.parameters(
dict(matrix=(np.random.RandomState(1234).randn(16, 2))),
dict(matrix=(np.random.RandomState(1234).randn(24, 1))),
dict(matrix=(np.random.RandomState(1234).randn(74, 4))),
)
def test_gram_schmidt_qr(self, matrix):
"""Tests that gram-schmidt_qr is close to numpy for slim matrices."""
q_actual, r_actual = array_utils.gram_schmidt_qr(matrix)
q, r = jnp.linalg.qr(matrix)
# we rearrange the result to make the diagonal of `r` positive.
d = jnp.diag(jnp.sign(jnp.diagonal(r)))
q_expected = q @ d
r_expected = d @ r
self.assertAllClose(q_expected, q_actual, atol=1e-4)
self.assertAllClose(r_expected, r_actual, atol=1e-4)
@parameterized.parameters(
dict(pytree=(np.zeros((6, 3)), np.ones((6, 2, 2))), idx=3, axis=0),
dict(pytree=(np.zeros((3, 8)), np.ones((6, 8))), idx=3, axis=-1),
dict(pytree={'a': np.zeros((3, 9)), 'b': np.ones((6, 9))}, idx=3, axis=1),
dict(pytree=np.zeros((13, 5)), idx=6, axis=0),
dict(pytree=(np.zeros(9), (np.ones((9, 1)), np.ones(9))), idx=6, axis=0),
)
def test_split_and_concat(self, pytree, idx, axis):
"""Tests that split_along_axis, concat_along_axis return expected shapes."""
split_a, split_b = array_utils.split_along_axis(pytree, idx, axis, False)
with self.subTest('split_shape'):
self.assertEqual(jax.tree_util.tree_leaves(split_a)[0].shape[axis], idx)
reconstruction = array_utils.concat_along_axis([split_a, split_b], axis)
with self.subTest('split_concat_roundtrip_structure'):
actual_tree_def = jax.tree_util.tree_structure(reconstruction)
expected_tree_def = jax.tree_util.tree_structure(pytree)
self.assertSameStructure(actual_tree_def, expected_tree_def)
actual_values = jax.tree_util.tree_leaves(reconstruction)
expected_values = jax.tree_util.tree_leaves(pytree)
with self.subTest('split_concat_roundtrip_values'):
for actual, expected in zip(actual_values, expected_values):
self.assertAllClose(actual, expected)
same_ndims = len(set(a.ndim for a in actual_values)) == 1
if not same_ndims:
with self.subTest('raises_when_wrong_ndims'):
with self.assertRaisesRegex(ValueError, 'arrays in `inputs` expected'):
split_a, split_b = array_utils.split_along_axis(pytree, idx, axis)
with self.subTest('multiple_concat_shape'):
arrays = [split_a, split_a, split_b, split_b]
double_concat = array_utils.concat_along_axis(arrays, axis)
actual_shape = jax.tree_util.tree_leaves(double_concat)[0].shape[axis]
expected_shape = jax.tree_util.tree_leaves(pytree)[0].shape[axis] * 2
self.assertEqual(actual_shape, expected_shape)
@parameterized.parameters(
dict(pytree=(np.zeros((6, 3)), np.ones((6, 2, 2))), axis=0),
dict(pytree={'a': np.zeros((3, 9)), 'b': np.ones((6, 9))}, axis=1),
dict(pytree=np.zeros((13, 5)), axis=0),
dict(pytree=(np.zeros(9), (np.ones((9, 1)), np.ones(9))), axis=0),
)
def test_split_along_axis_shapes(self, pytree, axis):
with self.subTest('with_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=True)
get_expected_shape = lambda x: x.shape[:axis] + (1,) + x.shape[axis + 1:]
expected_shapes = jax.tree_util.tree_map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_util.tree_map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)
with self.subTest('without_keep_dims'):
splits = array_utils.split_axis(pytree, axis, keep_dims=False)
get_expected_shape = lambda x: x.shape[:axis] + x.shape[axis + 1:]
expected_shapes = jax.tree_util.tree_map(get_expected_shape, pytree)
for split in splits:
actual = jax.tree_util.tree_map(lambda x: x.shape, split)
self.assertEqual(expected_shapes, actual)
class Interp1DTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='FillWithExtrapolate', fill_value='extrapolate'),
dict(testcase_name='FillWithScalar', fill_value=123),
)
def test_same_as_scipy_on_scalars_and_check_grads(self, fill_value):
rng = np.random.RandomState(45)
n = 10
y = rng.randn(n)
x_low = 5.
x_high = 9.
x = np.linspace(x_low, x_high, num=n)
sp_func = spi.interp1d(
x, y, kind='linear', fill_value=fill_value, bounds_error=False)
cfd_func = array_utils.interp1d(x, y, fill_value=fill_value)
# Check x_new at the table definition points `x`, points outside, and points
# in between.
x_to_check = np.concatenate(([x_low - 1], x, x * 1.051, [x_high + 1]))
for x_new in x_to_check:
sp_y_new = sp_func(x_new).astype(np.float32)
cfd_y_new = cfd_func(x_new)
self.assertAllClose(sp_y_new, cfd_y_new, rtol=1e-5, atol=1e-6)
grad_cfd_y_new = jax.grad(cfd_func)(x_new)
# Gradients should be nonzero except when outside the range of `x` and
# filling with a constant.
# Why check this? Because, some indexing methods result in gradients == 0
# at the interpolation table points.
if fill_value == 'extrapolate' or x_low <= x_new <= x_high:
self.assertLess(0, np.abs(grad_cfd_y_new))
else:
self.assertTrue(np.isfinite(grad_cfd_y_new))
@parameterized.named_parameters(
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithScalar',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value=12345,
),
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithScalar_NoAssumeSorted',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value=12345,
assume_sorted=False,
),
dict(
testcase_name='TableIs1D_XNewIs1D_Axis0_FillWithExtrapolate',
table_ndim=1,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
),
dict(
testcase_name='TableIs2D_XNewIs1D_Axis0_FillWithExtrapolate',
table_ndim=2,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
),
dict(
testcase_name=(
'TableIs2D_XNewIs1D_Axis0_FillWithExtrapolate_NoAssumeSorted'),
table_ndim=2,
x_new_ndim=1,
axis=0,
fill_value='extrapolate',
assume_sorted=False,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axis1_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=1,
fill_value=12345,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn1_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=-1,
fill_value=12345,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn1_FillWithExtrapolate',
table_ndim=3,
x_new_ndim=2,
axis=-1,
fill_value='extrapolate',
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn3_FillWithScalar',
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value=1234,
),
dict(
testcase_name='TableIs3D_XNewIs2D_Axisn3_FillWithConstantExtrapolate',
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value='constant_extrapolate',
),
dict(
testcase_name=(
'Table3D_XNew2D_Axisn3_FillConstantExtrapolate_NoAssumeSorted'),
table_ndim=3,
x_new_ndim=2,
axis=-3,
fill_value='constant_extrapolate',
assume_sorted=False,
),
)
def test_same_as_scipy_on_arrays(
self,
table_ndim,
x_new_ndim,
axis,
fill_value,
assume_sorted=True,
):
"""Test results are the same as scipy.interpolate.interp1d."""
rng = np.random.RandomState(45)
# Arbitrary shape that ensures all dims are different to prevent
# broadcasting from hiding bugs.
y_shape = tuple(range(5, 5 + table_ndim))
y = rng.randn(*y_shape)
n = y_shape[axis]
x_low = 5
x_high = 9
x = np.linspace(x_low, x_high, num=n)**2 # Arbitrary non-linearly spaced x
if not assume_sorted:
rng.shuffle(x)
# Scipy doesn't have 'constant_extrapolate', so treat it special.
# Here use np.nan as the fill value, which will be easy to spot if we handle
# it wrong.
if fill_value == 'constant_extrapolate':
sp_fill_value = np.nan
else:
sp_fill_value = fill_value
sp_func = spi.interp1d(
x,
y,
kind='linear',
axis=axis,
fill_value=sp_fill_value,
bounds_error=False,
assume_sorted=assume_sorted,
)
cfd_func = array_utils.interp1d(
x, y, axis=axis, fill_value=fill_value, assume_sorted=assume_sorted)
# Make n_x_new > n, so we can selectively fill values as below.
n_x_new = max(2 * n, 20)
# Make x_new over the same range as x.
x_new_shape = tuple(range(2, x_new_ndim + 1)) + (n_x_new,)
x_new = (x_low + rng.rand(*x_new_shape) * (x_high - x_low))**2
x_new[..., 0] = np.min(x) - 1 # Out of bounds low
x_new[..., -1] = np.max(x) + 1 # Out of bounds high
x_new[..., 1:n + 1] = x # All the grid points
# Scipy doesn't have the 'constant_extrapolate' feature, but
# constant_extrapolate is achieved by clipping the input.
if fill_value == 'constant_extrapolate':
sp_x_new = np.clip(x_new, np.min(x), np.max(x))
else:
sp_x_new = x_new
sp_y_new = sp_func(sp_x_new).astype(np.float32)
cfd_y_new = cfd_func(x_new)
self.assertAllClose(sp_y_new, cfd_y_new, rtol=1e-6, atol=1e-6)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(pnorgaard) Implement bicgstab for non-symmetric operators
"""Module for functionality related to diffusion."""
from typing import Optional, Tuple
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
def diffuse(c: GridVariable, nu: float) -> GridArray:
"""Returns the rate of change in a concentration `c` due to diffusion."""
return nu * fd.laplacian(c)
def stable_time_step(viscosity: float, grid: grids.Grid) -> float:
"""Calculate a stable time step size for explicit diffusion.
The calculation is based on analysis of central-time-central-space (CTCS)
schemes.
Args:
viscosity: kinematic viscosity.
grid: a `Grid` object.
Returns:
The prescribed time interval.
"""
if viscosity == 0:
return float('inf')
dx = min(grid.step)
ndim = grid.ndim
return dx ** 2 / (viscosity * 2 ** ndim)
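# A minimal usage sketch: on a 256 x 256 unit-square grid (dx = 1/256) with
# viscosity 1e-3, the explicit limit is dx**2 / (viscosity * 2**ndim), roughly
# 3.8e-3. The grid construction below is illustrative.
def _stable_diffusion_time_step_example() -> float:
  grid = grids.Grid((256, 256), domain=((0.0, 1.0), (0.0, 1.0)))
  return stable_time_step(viscosity=1e-3, grid=grid)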
def _subtract_linear_part_dirichlet(
c_data: Array,
grid: grids.Grid,
axis: int,
offset: Tuple[float, float],
bc_values: Tuple[float, float],
) -> Array:
"""Transforms c_data such that c_data satisfies dirichlet boundary.
The function subtracts a linear function from c_data s.t. the returned
array has homogeneous dirichlet boundaries. Note that this assumes c_data has
constant dirichlet boundary values.
Args:
c_data: right-hand-side of diffusion equation.
grid: grid object
axis: axis along which to impose boundary transformation
offset: offset of the right-hand-side
bc_values: boundary values along axis
Returns:
transformed right-hand-side
"""
def _update_rhs_along_axis(arr_1d, linear_part):
arr_1d = arr_1d - linear_part
return arr_1d
lower_value, upper_value = bc_values
y = grid.mesh(offset)[axis][0]
one_d_grid = grids.Grid((grid.shape[axis],), domain=(grid.domain[axis],))
y_boundary = boundaries.dirichlet_boundary_conditions(ndim=1)
y = y_boundary.trim_boundary(grids.GridArray(y, (offset[axis],),
one_d_grid)).data
domain_length = (grid.domain[axis][1] - grid.domain[axis][0])
domain_start = grid.domain[axis][0]
linear_part = lower_value + (upper_value - lower_value) * (
y - domain_start) / domain_length
c_data = jnp.apply_along_axis(
_update_rhs_along_axis, axis, c_data, linear_part)
return c_data
def _rhs_transform(
u: grids.GridArray,
bc: boundaries.BoundaryConditions,
) -> Array:
"""Transforms the RHS of diffusion equation.
In case of constant dirichlet boundary conditions for heat equation
the linear term is subtracted. See diffusion.solve_fast_diag.
Args:
u: a GridArray that solves ∇²x = ∇²u for x.
bc: specifies boundary of u.
Returns:
u' s.t. u = u' + w where u' has 0 dirichlet bc and w is linear.
"""
if not isinstance(bc, boundaries.ConstantBoundaryConditions):
raise NotImplementedError(
f'transformation cannot be done for this {bc}.')
u_data = u.data
for axis in range(u.grid.ndim):
for i, _ in enumerate(['lower', 'upper']): # lower and upper boundary
if bc.types[axis][i] == boundaries.BCType.DIRICHLET:
bc_values = [0., 0.]
bc_values[i] = bc.bc_values[axis][i]
u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset,
bc_values)
elif bc.types[axis][i] == boundaries.BCType.NEUMANN:
if any(bc.bc_values[axis]):
raise NotImplementedError(
'transformation is not implemented for inhomogeneous Neumann bc.')
return u_data
def solve_cg(v: GridVariableVector,
nu: float,
dt: float,
rtol: float = 1e-6,
atol: float = 1e-6,
maxiter: Optional[int] = None) -> GridVariableVector:
"""Conjugate gradient solve for diffusion."""
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('solve_cg() expects periodic BC')
def solve_component(u: GridVariable) -> GridArray:
"""Solves (1 - ν Δt ∇²) u_{t+1} = u_{tilda} for u_{t+1}."""
def linear_op(u_new: GridArray) -> GridArray:
"""Linear operator for (1 - ν Δt ∇²) u_{t+1}."""
u_new = grids.GridVariable(u_new, u.bc) # get boundary condition from u
return u_new.array - dt * nu * fd.laplacian(u_new)
def cg(b: GridArray, x0: GridArray) -> GridArray:
"""Iteratively solves Lx = b. with initial guess x0."""
x, _ = jax.scipy.sparse.linalg.cg(
linear_op, b, x0=x0, tol=rtol, atol=atol, maxiter=maxiter)
return x
return cg(u.array, u.array)
return tuple(grids.GridVariable(solve_component(u), u.bc) for u in v)
def solve_fast_diag(
v: GridVariableVector,
nu: float,
dt: float,
implementation: Optional[str] = None,
) -> GridVariableVector:
"""Solve for diffusion using the fast diagonalization approach."""
# We reuse eigenvectors from the Laplacian and transform the eigenvalues
# because this is better conditioned than directly diagonalizing 1 - ν Δt ∇²
# when ν Δt is small.
def func(x):
dt_nu_x = (dt * nu) * x
return dt_nu_x / (1 - dt_nu_x)
# Compute (1 - ν Δt ∇²)⁻¹ u as u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u, for less
# error when ν Δt is small.
# If dirichlet bc are supplied: only works for dirichlet bc that are linear
# functions on the boundary. Then u = u' + w where u' has 0 dirichlet bc and
# w is linear. Then u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u = u +
# (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function _rhs_transform subtracts
# the linear part s.t. fast_diagonalization solves
# u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u'.
v_diffused = list()
if boundaries.has_all_periodic_boundary_conditions(*v):
circulant = True
else:
circulant = False
# only matmul implementation supports non-circulant matrices
implementation = 'matmul'
for u in v:
laplacians = array_utils.laplacian_matrix_w_boundaries(
u.grid, u.offset, u.bc)
op = fast_diagonalization.transform(
func,
laplacians,
v[0].dtype,
hermitian=True,
circulant=circulant,
implementation=implementation)
u_interior = u.bc.trim_boundary(u.array)
u_interior_transformed = _rhs_transform(u_interior, u.bc)
u_dt_diffused = grids.GridArray(
op(u_interior_transformed), u_interior.offset, u_interior.grid)
u_diffused = u_interior + u_dt_diffused
u_diffused = u.bc.pad_and_impose_bc(u_diffused, offset_to_pad_to=u.offset)
v_diffused.append(u_diffused)
return tuple(v_diffused)
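# A minimal usage sketch contrasting the explicit and implicit paths above,
# assuming `v` is a GridVariableVector built elsewhere. An explicit update must
# respect `stable_time_step`, while `solve_fast_diag` has no such restriction.
# The helper name and parameter values are illustrative, not library API.
def _diffusion_step_example(v: GridVariableVector, nu: float,
                            dt: float) -> GridVariableVector:
  """One implicit diffusion update; an explicit alternative is shown inline."""
  # Explicit (forward Euler) alternative, valid only for sufficiently small dt:
  #   tuple(u.bc.impose_bc(u.array + dt * diffuse(u, nu)) for u in v)
  return solve_fast_diag(v, nu, dt)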