# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.fast_diagonalization."""
from typing import Sequence
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import test_util
import numpy as np
import scipy.linalg
Array = fast_diagonalization.Array
def apply_operators(
operators: Sequence[np.ndarray],
rhs: Array,
) -> Array:
"""Apply a sum of linear operators along all array axes."""
assert len(operators) == rhs.ndim
out = 0
for axis, matrix in enumerate(operators):
axes = [i if i != axis else rhs.ndim for i in range(rhs.ndim)]
out += jnp.einsum(
matrix, [axis, rhs.ndim], rhs, list(range(rhs.ndim)), axes)
return out
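# Note (added for clarity): for a 2D `rhs` and symmetric operators [A, B], the
# loop above computes A @ rhs + rhs @ B, i.e. each operator is contracted
# against its own axis of `rhs` and the contributions are summed. This is the
# dense analogue of the linear system that fast_diagonalization.pseudoinverse
# is asked to invert in the tests below.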
class FastDiagonalizationTest(test_util.TestCase):
def test_random_1d_matmul(self):
rs = np.random.RandomState(0)
a = rs.randn(3, 3)
a = jnp.array(a + a.T, np.float32)
b = rs.randn(3).astype(np.float32)
a_inv = fast_diagonalization.pseudoinverse(
[a], b.dtype, hermitian=True, implementation='matmul')
actual = a_inv(b)
expected = jnp.linalg.solve(a, b)
self.assertAllClose(actual, expected, atol=1e-6)
@parameterized.parameters('fft', 'rfft')
def test_random_1d_fft(self, implementation):
rs = np.random.RandomState(0)
a = jnp.array(scipy.linalg.circulant(rs.randn(4)), np.float32)
b = rs.randn(4).astype(np.float32)
a_inv = fast_diagonalization.pseudoinverse(
[a], b.dtype, circulant=True, implementation=implementation)
actual = a_inv(b)
expected = jnp.linalg.solve(a, b)
self.assertAllClose(actual, expected, atol=1e-5)
@parameterized.parameters(
*[(ndim, 'matmul') for ndim in [1, 2, 3]],
*[(ndim, 'fft') for ndim in [1, 2, 3]],
*[(ndim, 'rfft') for ndim in [1, 2, 3]],
)
def test_identity_nd(self, ndim, implementation):
rs = np.random.RandomState(0)
b = rs.randn(*(2, 4, 6)[:ndim]).astype(np.float32)
ops = [np.eye(2), 2 * np.eye(4), 3 * np.eye(6)]
a_inv = fast_diagonalization.pseudoinverse(
ops[:ndim], b.dtype, hermitian=True, circulant=True,
implementation=implementation)
actual = a_inv(b)
expected = b / sum(range(1, 1 + ndim))
self.assertAllClose(actual, expected, rtol=1e-5, atol=1e-5)
@parameterized.parameters('matmul', 'fft', 'rfft')
def test_poisson_1d(self, implementation):
rs = np.random.RandomState(0)
a = jnp.array([[-2, 1, 0, 1], [1, -2, 1, 0], [0, 1, -2, 1], [1, 0, 1, -2]],
np.float32)
b = rs.randn(4).astype(np.float32)
a_inv = fast_diagonalization.pseudoinverse(
[a], a.dtype, hermitian=True, circulant=True,
implementation=implementation)
x = a_inv(b)
self.assertAllClose(jnp.dot(a, x), b - b.mean(), atol=1e-5)
@parameterized.parameters(
dict(periodic_x=False, periodic_y=False),
dict(periodic_x=False, periodic_y=True),
dict(periodic_x=True, periodic_y=True),
)
def test_poisson_2d_matmul(self, periodic_x, periodic_y):
a1 = jnp.array([[-2, 1, 0, periodic_x], [1, -2, 1, 0], [0, 1, -2, 1],
[periodic_x, 0, 1, -2]], dtype=np.float32)
a2 = jnp.array([[-2, 1, periodic_y], [1, -2, 1], [periodic_y, 1, -2]],
dtype=np.float32)
b = np.random.RandomState(0).randn(4, 3).astype(np.float32)
operators = [a1, a2]
a_inv = fast_diagonalization.pseudoinverse(
operators, b.dtype, hermitian=True)
x = a_inv(b)
actual = apply_operators(operators, x)
expected = b.copy()
if periodic_x and periodic_y:
expected -= expected.mean()
self.assertAllClose(actual, expected, atol=1e-5)
@parameterized.parameters('fft', 'rfft')
def test_poisson_2d_fft(self, implementation):
a1 = array_utils.laplacian_matrix(size=4, step=1.0)
a2 = array_utils.laplacian_matrix(size=6, step=1.0)
b = np.random.RandomState(0).randn(4, 6).astype(np.float32)
operators = [a1, a2]
a_inv = fast_diagonalization.pseudoinverse(
operators, b.dtype, circulant=True, implementation=implementation)
x = a_inv(b)
actual = apply_operators(operators, x)
expected = b - b.mean()
self.assertAllClose(actual, expected, atol=1e-5)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for spectral filtering."""
from typing import Callable
import jax.numpy as jnp
from jax_cfd.base import grids
Array = grids.Array
def _angular_frequency_magnitude(grid: grids.Grid) -> Array:
frequencies = [2 * jnp.pi * jnp.fft.fftfreq(size, step)
for size, step in zip(grid.shape, grid.step)]
freq_vector = jnp.stack(jnp.meshgrid(*frequencies, indexing='ij'), axis=0)
return jnp.linalg.norm(freq_vector, axis=0)
def filter( # pylint: disable=redefined-builtin
spectral_density: Callable[[Array], Array],
array: Array,
grid: grids.Grid,
) -> Array:
"""Filter an Array with white noise to match a prescribed spectral density."""
k = _angular_frequency_magnitude(grid)
filters = jnp.where(k > 0, spectral_density(k), 0.0)
# The output signal can safely be assumed to be real if our input signal was
# real, because our spectral density only depends on norm(k).
return jnp.fft.ifftn(jnp.fft.fftn(array) * filters).real
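# Example usage (illustrative sketch, not part of the library): shape white
# noise so that its Fourier amplitudes follow a prescribed spectral density.
# The grid size and the k**(-5/3) density below are arbitrary choices.
#
#   import jax
#   grid = grids.Grid((64, 64), domain=((0, 2 * jnp.pi),) * 2)
#   noise = jax.random.normal(jax.random.PRNGKey(0), grid.shape)
#   shaped_noise = filter(lambda k: k ** (-5 / 3), noise, grid)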
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for approximating derivatives.
Finite difference methods operate on GridVariable and return GridArray.
Evaluating finite differences requires boundary conditions, which are defined
for a GridVariable. But the operation of taking a derivative makes the boundary
condition undefined, which is why the return type is GridArray.
For example, if the variable c has the boundary condition c_b = 0, and we take
the derivate dc/dx, then it is unclear what the boundary condition on dc/dx
should be. So programmatically, after taking finite differences and doing
operations, the user has to explicitly assign boundary conditions to the result.
Example:
c = GridVariable(c_array, c_boundary_condition)
dcdx = finite_differences.forward_difference(c) # returns GridArray
c_new = c + dt * (-velocity * dcdx) # operations on GridArrays
c_new = GridVariable(c_new, c_boundary_condition)  # associate BCs
"""
import typing
from typing import Optional, Sequence, Tuple
from jax_cfd.base import grids
from jax_cfd.base import interpolation
import numpy as np
GridArray = grids.GridArray
GridVariable = grids.GridVariable
GridArrayTensor = grids.GridArrayTensor
def stencil_sum(*arrays: GridArray) -> GridArray:
"""Sum arrays across a stencil, with an averaged offset."""
# pylint: disable=line-too-long
offset = grids.averaged_offset(*arrays)
# pytype appears to have a bad type signature for sum():
# Built-in function sum was called with the wrong arguments [wrong-arg-types]
# Expected: (iterable: Iterable[complex])
# Actually passed: (iterable: Generator[Union[jax.interpreters.xla.DeviceArray, numpy.ndarray], Any, None])
result = sum(array.data for array in arrays) # type: ignore
grid = grids.consistent_grid(*arrays)
return grids.GridArray(result, offset, grid)
# incompatible with typing.overload
# pylint: disable=pointless-statement
# pylint: disable=function-redefined
# pylint: disable=unused-argument
@typing.overload
def central_difference(u: GridVariable, axis: int) -> GridArray:
...
@typing.overload
def central_difference(
u: GridVariable, axis: Optional[Tuple[int, ...]]) -> Tuple[GridArray, ...]:
...
def central_difference(u, axis=None):
"""Approximates grads with central differences."""
if axis is None:
axis = range(u.grid.ndim)
if not isinstance(axis, int):
return tuple(central_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations
diff = stencil_sum(u.shift(+1, axis), -u.shift(-1, axis))
return diff / (2 * u.grid.step[axis])
@typing.overload
def backward_difference(u: GridVariable, axis: int) -> GridArray:
...
@typing.overload
def backward_difference(
u: GridVariable, axis: Optional[Tuple[int, ...]]) -> Tuple[GridArray, ...]:
...
def backward_difference(u, axis=None):
"""Approximates grads with finite differences in the backward direction."""
if axis is None:
axis = range(u.grid.ndim)
if not isinstance(axis, int):
return tuple(backward_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations
diff = stencil_sum(u.array, -u.shift(-1, axis))
return diff / u.grid.step[axis]
@typing.overload
def forward_difference(u: GridVariable, axis: int) -> GridArray:
...
@typing.overload
def forward_difference(
u: GridVariable,
axis: Optional[Tuple[int, ...]] = ...) -> Tuple[GridArray, ...]:
...
def forward_difference(u, axis=None):
"""Approximates grads with finite differences in the forward direction."""
if axis is None:
axis = range(u.grid.ndim)
if not isinstance(axis, int):
return tuple(forward_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations
diff = stencil_sum(u.shift(+1, axis), -u.array)
return diff / u.grid.step[axis]
def laplacian(u: GridVariable) -> GridArray:
"""Approximates the Laplacian of `u`."""
scales = np.square(1 / np.array(u.grid.step, dtype=u.dtype))
result = -2 * u.array * np.sum(scales)
for axis in range(u.grid.ndim):
result += stencil_sum(u.shift(-1, axis), u.shift(+1, axis)) * scales[axis]
return result
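# Stencils implemented above (h = grid.step[axis], interior points):
#   forward_difference:  (u[i+1] - u[i]) / h
#   backward_difference: (u[i] - u[i-1]) / h
#   central_difference:  (u[i+1] - u[i-1]) / (2 * h)
#   laplacian, 1D case:  (u[i-1] - 2 * u[i] + u[i+1]) / h**2
# Values beyond the array edges come from u.shift(), i.e. from the boundary
# conditions attached to the GridVariable.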
def divergence(v: Sequence[GridVariable]) -> GridArray:
"""Approximates the divergence of `v` using backward differences."""
grid = grids.consistent_grid(*v)
if len(v) != grid.ndim:
raise ValueError('The length of `v` must be equal to `grid.ndim`. '
f'Expected length {grid.ndim}; got {len(v)}.')
differences = [backward_difference(u, axis) for axis, u in enumerate(v)]
return sum(differences)
def centered_divergence(v: Sequence[GridVariable]) -> GridArray:
"""Approximates the divergence of `v` using centered differences."""
grid = grids.consistent_grid(*v)
if len(v) != grid.ndim:
raise ValueError('The length of `v` must be equal to `grid.ndim`. '
f'Expected length {grid.ndim}; got {len(v)}.')
differences = [central_difference(u, axis) for axis, u in enumerate(v)]
return sum(differences)
@typing.overload
def gradient_tensor(v: GridVariable) -> GridArrayTensor:
...
@typing.overload
def gradient_tensor(v: Sequence[GridVariable]) -> GridArrayTensor:
...
def gradient_tensor(v):
"""Approximates the cell-centered gradient of `v`."""
if not isinstance(v, GridVariable):
return GridArrayTensor(np.stack([gradient_tensor(u) for u in v], axis=-1)) # pytype: disable=wrong-arg-types # always-use-return-annotations
grad = []
for axis in range(v.grid.ndim):
offset = v.offset[axis]
if offset == 0:
derivative = forward_difference(v, axis)
elif offset == 1:
derivative = backward_difference(v, axis)
elif offset == 0.5:
v_centered = interpolation.linear(v, v.grid.cell_center)
derivative = central_difference(v_centered, axis)
else:
raise ValueError(f'expected offset values in {{0, 0.5, 1}}, got {offset}')
grad.append(derivative)
return GridArrayTensor(grad)
def curl_2d(v: Sequence[GridVariable]) -> GridArray:
"""Approximates the curl of `v` in 2D using forward differences."""
if len(v) != 2:
raise ValueError(f'Length of `v` is not 2: {len(v)}')
grid = grids.consistent_grid(*v)
if grid.ndim != 2:
raise ValueError(f'Grid dimensionality is not 2: {grid.ndim}')
return forward_difference(v[1], axis=0) - forward_difference(v[0], axis=1)
def curl_3d(
v: Sequence[GridVariable]) -> Tuple[GridArray, GridArray, GridArray]:
"""Approximates the curl of `v` in 2D using forward differences."""
if len(v) != 3:
raise ValueError(f'Length of `v` is not 3: {len(v)}')
grid = grids.consistent_grid(*v)
if grid.ndim != 3:
raise ValueError(f'Grid dimensionality is not 3: {grid.ndim}')
curl_x = (forward_difference(v[2], axis=1) - forward_difference(v[1], axis=2))
curl_y = (forward_difference(v[0], axis=2) - forward_difference(v[2], axis=0))
curl_z = (forward_difference(v[1], axis=0) - forward_difference(v[0], axis=1))
return (curl_x, curl_y, curl_z)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.grids."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
def _trim_boundary(array):
if isinstance(array, grids.GridArray):
data = array.data[(slice(1, -1),) * array.data.ndim]
return grids.GridArray(data, array.offset, array.grid)
else:
return jnp.asarray(array)[(slice(1, -1),) * array.ndim]
def periodic_grid_variable(data, offset, grid):
return grids.GridVariable(
array=grids.GridArray(data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
class FiniteDifferenceTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_central_difference_periodic',
method=fd.central_difference,
shape=(3, 3, 3),
step=(1., 1., 1.),
expected_offset=0,
negative_shift=-1,
positive_shift=1),
dict(testcase_name='_backward_difference_periodic',
method=fd.backward_difference,
shape=(2, 3, 4),
step=(.1, .3, 1.),
expected_offset=-0.5,
negative_shift=-1,
positive_shift=0),
dict(testcase_name='_forward_difference_periodic',
method=fd.forward_difference,
shape=(23, 32, 1),
step=(.01, 2., .1),
expected_offset=+0.5,
negative_shift=0,
positive_shift=1),
)
def test_finite_difference_indexing(
self, method, shape, step, expected_offset, negative_shift,
positive_shift):
"""Tests finite difference code using explicit indices."""
grid = grids.Grid(shape, step)
u = periodic_grid_variable(
jnp.arange(np.prod(shape)).reshape(shape), (0, 0, 0), grid)
actual_x, actual_y, actual_z = method(u)
x, y, z = jnp.meshgrid(*[jnp.arange(s) for s in shape], indexing='ij')
diff_x = (u.data[jnp.roll(x, -positive_shift, axis=0), y, z] -
u.data[jnp.roll(x, -negative_shift, axis=0), y, z])
expected_data_x = diff_x / step[0] / (positive_shift - negative_shift)
expected_x = grids.GridArray(expected_data_x, (expected_offset, 0, 0), grid)
diff_y = (u.data[x, jnp.roll(y, -positive_shift, axis=1), z] -
u.data[x, jnp.roll(y, -negative_shift, axis=1), z])
expected_data_y = diff_y / step[1] / (positive_shift - negative_shift)
expected_y = grids.GridArray(expected_data_y, (0, expected_offset, 0), grid)
diff_z = (u.data[x, y, jnp.roll(z, -positive_shift, axis=2)] -
u.data[x, y, jnp.roll(z, -negative_shift, axis=2)])
expected_data_z = diff_z / step[2] / (positive_shift - negative_shift)
expected_z = grids.GridArray(expected_data_z, (0, 0, expected_offset), grid)
self.assertArrayEqual(expected_x, actual_x)
self.assertArrayEqual(expected_y, actual_y)
self.assertArrayEqual(expected_z, actual_z)
@parameterized.named_parameters(
dict(testcase_name='_central_difference_periodic',
method=fd.central_difference,
shape=(100, 100, 100),
offset=(0, 0, 0),
f=lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.sin(z),
gradf=(lambda x, y, z: -jnp.sin(x) * jnp.cos(y) * jnp.sin(z),
lambda x, y, z: -jnp.cos(x) * jnp.sin(y) * jnp.sin(z),
lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.cos(z)),
atol=1e-3),
dict(testcase_name='_backward_difference_periodic',
method=fd.backward_difference,
shape=(100, 100, 100),
offset=(0, 0, 0),
f=lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.sin(z),
gradf=(lambda x, y, z: -jnp.sin(x) * jnp.cos(y) * jnp.sin(z),
lambda x, y, z: -jnp.cos(x) * jnp.sin(y) * jnp.sin(z),
lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.cos(z)),
atol=5e-2),
dict(testcase_name='_forward_difference_periodic',
method=fd.forward_difference,
shape=(200, 200, 200),
offset=(0, 0, 0),
f=lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.sin(z),
gradf=(lambda x, y, z: -jnp.sin(x) * jnp.cos(y) * jnp.sin(z),
lambda x, y, z: -jnp.cos(x) * jnp.sin(y) * jnp.sin(z),
lambda x, y, z: jnp.cos(x) * jnp.cos(y) * jnp.cos(z)),
atol=5e-2),
)
def test_finite_difference_analytic(
self, method, shape, offset, f, gradf, atol):
"""Tests finite difference code comparing to the analytice solution."""
step = tuple([2. * jnp.pi / s for s in shape])
grid = grids.Grid(shape, step)
mesh = grid.mesh()
u = periodic_grid_variable(f(*mesh), offset, grid)
expected_grad = jnp.stack([df(*mesh) for df in gradf])
actual_grad = [array.data for array in method(u)]
self.assertAllClose(expected_grad, actual_grad, atol=atol)
@parameterized.named_parameters(
dict(testcase_name='_2D_constant',
shape=(20, 20),
f=lambda x, y: np.ones_like(x),
g=lambda x, y: np.zeros_like(x),
atol=1e-3),
dict(testcase_name='_2D_quadratic',
shape=(21, 21),
f=lambda x, y: x * (x - 1.) + y * (y - 1.),
g=lambda x, y: 4 * np.ones_like(x),
atol=1e-3),
dict(testcase_name='_3D_quadratic',
shape=(13, 13, 13),
f=lambda x, y, z: x * (x - 1.) + y * (y - 1.) + z * (z - 1.),
g=lambda x, y, z: 6 * np.ones_like(x),
atol=1e-3),
)
def test_laplacian(self, shape, f, g, atol):
step = tuple([1. / s for s in shape])
grid = grids.Grid(shape, step)
offset = (0,) * len(shape)
mesh = grid.mesh(offset)
u = periodic_grid_variable(f(*mesh), offset, grid)
expected_laplacian = _trim_boundary(grids.GridArray(g(*mesh), offset, grid))
actual_laplacian = _trim_boundary(fd.laplacian(u))
self.assertAllClose(expected_laplacian, actual_laplacian, atol=atol)
@parameterized.named_parameters(
dict(testcase_name='_2D_constant',
shape=(20, 20),
offsets=((0.5, 0), (0, 0.5)),
f=lambda x, y: (np.ones_like(x), np.ones_like(y)),
g=lambda x, y: jnp.zeros_like(x),
atol=1e-3),
dict(testcase_name='_2D_quadratic',
shape=(21, 21),
offsets=((0.5, 0), (0, 0.5)),
f=lambda x, y: (x * (x - 1.), y * (y - 1.)),
g=lambda x, y: 2 * x + 2 * y - 2,
atol=0.1),
dict(testcase_name='_3D_quadratic',
shape=(13, 13, 13),
offsets=((0.5, 0, 0), (0, 0.5, 0), (0, 0, 0.5)),
f=lambda x, y, z: (x * (x - 1.), y * (y - 1.), z * (z - 1.)),
g=lambda x, y, z: 2 * x + 2 * y + 2 * z - 3,
atol=0.25),
)
def test_divergence(self, shape, offsets, f, g, atol):
step = tuple([1. / s for s in shape])
grid = grids.Grid(shape, step)
v = [periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)]
expected_divergence = _trim_boundary(
grids.GridArray(g(*grid.mesh()), (0,) * grid.ndim, grid))
actual_divergence = _trim_boundary(fd.divergence(v))
self.assertAllClose(expected_divergence, actual_divergence, atol=atol)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
dict(
testcase_name='_2D_constant',
shape=(20, 20),
f=lambda x, y: (np.ones_like(x), np.ones_like(y)),
g=lambda x, y: np.array([[jnp.zeros_like(x)] * 2] * 2),
atol=0),
dict(
testcase_name='_2D_quadratic',
shape=(21, 21),
f=lambda x, y: (x * (y - 1.), y * (x - 2.)),
g=lambda x, y: np.array([[y - 1., y], [x, x - 2.]]),
atol=3e-6),
dict(
testcase_name='_2D_quartic',
shape=(21, 21),
f=lambda x, y: (x**2 * y**2, (x - 1.)**3 * (y - 2.)),
g=lambda x, y: np.array([[2 * x * y**2, 3 * (x - 1.)**2 *
(y - 2.)], [2 * x**2 * y, (x - 1.)**3]]),
atol=1e-2),
dict(
testcase_name='_3D_quadratic',
shape=(13, 13, 13),
f=lambda x, y, z: (x * (y - 1.), y * (z - 2.), z * (x - 3.)),
g=lambda x, y, z: np.array([[y - 1., jnp.zeros_like(x), z],
[x, z - 2., jnp.zeros_like(x)],
[jnp.zeros_like(x), y, x - 3.]]),
atol=4e-6),
)
# pylint: enable=g-long-lambda
def test_cell_centered_gradient(self, shape, f, g, atol):
step = tuple([1. / s for s in shape])
grid = grids.Grid(shape, step)
with self.subTest('cell center values'):
offsets = (grid.cell_center,) * grid.ndim
v = [
periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)
]
expected_gradient = g(*grid.mesh())
actual_gradient = fd.gradient_tensor(v)
for i in range(grid.ndim):
for j in range(len(v)):
print('i and j are', i, j)
expected = _trim_boundary(expected_gradient[i, j])
actual = _trim_boundary(actual_gradient[i, j])
self.assertAllClose(expected, actual.data, atol=atol)
with self.subTest('cell face values'):
offsets = grid.cell_faces
v = [periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)]
expected_gradient = g(*grid.mesh())
actual_gradient = fd.gradient_tensor(v)
for i in range(grid.ndim):
for j in range(len(v)):
print('i and j are', i, j)
expected = _trim_boundary(expected_gradient[i, j])
actual = _trim_boundary(actual_gradient[i, j])
self.assertAllClose(expected, actual.data, atol=atol)
with self.subTest('raises'):
offsets = ((0.1,) * grid.ndim,) * grid.ndim # unsupported offset
v = [
periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)
]
with self.assertRaisesRegex(ValueError, 'expected offset values'):
fd.gradient_tensor(v)
@parameterized.named_parameters(
# https://en.wikipedia.org/wiki/Curl_(mathematics)#Examples
dict(testcase_name='_wikipedia_example_1',
shape=(20, 20),
offsets=((0.5, 0), (0, 0.5)),
f=lambda x, y: (y, -x),
g=lambda x, y: -2 * np.ones_like(x),
atol=1e-3),
dict(testcase_name='_wikipedia_example_2',
shape=(21, 21),
offsets=((0.5, 0), (0, 0.5)),
f=lambda x, y: (np.ones_like(x), -x**2),
g=lambda x, y: -2 * x,
atol=1e-3),
)
def test_curl_2d(self, shape, offsets, f, g, atol):
step = tuple([1. / s for s in shape])
grid = grids.Grid(shape, step)
v = [periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)]
result_offset = (0.5, 0.5)
expected_curl = _trim_boundary(
grids.GridArray(g(*grid.mesh(result_offset)), result_offset, grid))
actual_curl = _trim_boundary(fd.curl_2d(v))
self.assertAllClose(expected_curl, actual_curl, atol=atol)
@parameterized.named_parameters(
# https://www.web-formulas.com/Math_Formulas/Linear_Algebra_Curl_of_a_Vector_Field.aspx
dict(testcase_name='_web_formulas_example_3',
shape=(13, 13, 13),
offsets=((0.5, 0, 0), (0, 0.5, 0), (0, 0, 0.5)),
expected_offsets=((0, 0.5, 0.5), (0.5, 0, 0.5), (0.5, 0.5, 0)),
f=lambda x, y, z: (x + y + z, x - y - z, x**2 + y**2 + z**2),
g=lambda x, y, z: (2 * y + 1, 1 - 2 * x, np.zeros_like(x)),
atol=1e-3),
)
def test_curl_3d(
self, shape, offsets, expected_offsets, f, g, atol):
step = tuple([1. / s for s in shape])
grid = grids.Grid(shape, step)
v = [periodic_grid_variable(f(*grid.mesh(offset))[axis], offset, grid)
for axis, offset in enumerate(offsets)]
expected_curl = [
_trim_boundary(
grids.GridArray(g(*grid.mesh(offset))[axis], offset, grid))
for axis, offset in enumerate(expected_offsets)
]
actual_curl = list(map(_trim_boundary, fd.curl_3d(v)))
self.assertEqual(len(actual_curl), 3)
self.assertAllClose(expected_curl[0], actual_curl[0], atol=atol)
self.assertAllClose(expected_curl[1], actual_curl[1], atol=atol)
self.assertAllClose(expected_curl[2], actual_curl[2], atol=atol)
if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forcing functions for Navier-Stokes equations."""
# TODO(jamieas): change the signature for all forcing functions so that they
# close over `grid`.
import functools
from typing import Callable, Optional, Tuple
import jax.numpy as jnp
from jax_cfd.base import equations
from jax_cfd.base import filter_utils
from jax_cfd.base import grids
from jax_cfd.base import validation_problems
Array = grids.Array
GridArrayVector = grids.GridArrayVector
GridVariableVector = grids.GridVariableVector
ForcingFn = Callable[[GridVariableVector], GridArrayVector]
def taylor_green_forcing(
grid: grids.Grid, scale: float = 1, k: int = 2,
) -> ForcingFn:
"""Constant driving forced in the form of Taylor-Green vorcities."""
u, v = validation_problems.TaylorGreen(
shape=grid.shape[:2], kx=k, ky=k).velocity()
# Put force on same offset, grid as velocity components
if grid.ndim == 2:
u = grids.GridArray(u.data * scale, u.offset, grid)
v = grids.GridArray(v.data * scale, v.offset, grid)
f = (u, v)
elif grid.ndim == 3:
# append z-dimension to u,v arrays
u_data = jnp.broadcast_to(jnp.expand_dims(u.data * scale, -1), grid.shape)
v_data = jnp.broadcast_to(jnp.expand_dims(v.data * scale, -1), grid.shape)
u = grids.GridArray(u_data, (1, 0.5, 0.5), grid)
v = grids.GridArray(v_data, (0.5, 1, 0.5), grid)
w = grids.GridArray(jnp.zeros_like(u.data), (0.5, 0.5, 1), grid)
f = (u, v, w)
else:
raise NotImplementedError
def forcing(v):
del v
return f
return forcing
def kolmogorov_forcing(
grid: grids.Grid,
scale: float = 1,
k: int = 2,
swap_xy: bool = False,
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> ForcingFn:
"""Returns the Kolmogorov forcing function for turbulence in 2D."""
if offsets is None:
offsets = grid.cell_faces
if swap_xy:
x = grid.mesh(offsets[1])[0]
v = scale * grids.GridArray(jnp.sin(k * x), offsets[1], grid)
if grid.ndim == 2:
u = grids.GridArray(jnp.zeros_like(v.data), (1, 1/2), grid)
f = (u, v)
elif grid.ndim == 3:
u = grids.GridArray(jnp.zeros_like(v.data), (1, 1/2, 1/2), grid)
w = grids.GridArray(jnp.zeros_like(u.data), (1/2, 1/2, 1), grid)
f = (u, v, w)
else:
raise NotImplementedError
else:
y = grid.mesh(offsets[0])[1]
u = scale * grids.GridArray(jnp.sin(k * y), offsets[0], grid)
if grid.ndim == 2:
v = grids.GridArray(jnp.zeros_like(u.data), (1/2, 1), grid)
f = (u, v)
elif grid.ndim == 3:
v = grids.GridArray(jnp.zeros_like(u.data), (1/2, 1, 1/2), grid)
w = grids.GridArray(jnp.zeros_like(u.data), (1/2, 1/2, 1), grid)
f = (u, v, w)
else:
raise NotImplementedError
def forcing(v):
del v
return f
return forcing
def linear_forcing(grid, coefficient: float) -> ForcingFn:
"""Linear forcing, proportional to velocity."""
del grid
def forcing(v):
return tuple(coefficient * u.array for u in v)
return forcing
def no_forcing(grid):
"""Zero-valued forcing field for unforced simulations."""
del grid
def forcing(v):
return tuple(0 * u.array for u in v)
return forcing
def sum_forcings(*forcings: ForcingFn) -> ForcingFn:
"""Sum multiple forcing functions."""
def forcing(v):
return equations.sum_fields(*[forcing(v) for forcing in forcings])
return forcing
FORCING_FUNCTIONS = dict(kolmogorov=kolmogorov_forcing,
taylor_green=taylor_green_forcing)
def simple_turbulence_forcing(
grid: grids.Grid,
constant_magnitude: float = 0,
constant_wavenumber: int = 2,
linear_coefficient: float = 0,
forcing_type: str = 'kolmogorov',
) -> ForcingFn:
"""Returns a forcing function for turbulence in 2D or 3D.
2D turbulence needs a driving force injecting energy at intermediate
length-scales, and a damping force at long length-scales to avoid all energy
accumulating in giant vortices. This can be achieved with
`constant_magnitude > 0` and `linear_coefficient < 0`.
3D turbulence only needs a driving force at the longest length-scale (damping
happens at the smallest length-scales due to viscosity and/or numerical
dispersion). This can be achieved with `constant_magnitude = 0` and
`linear_coefficient > 0`.
Args:
grid: grid on which to simulate.
constant_magnitude: magnitude for constant forcing with Taylor-Green
vortices.
constant_wavenumber: wavenumber for constant forcing with Taylor-Green
vortices.
linear_coefficient: forcing coefficient proportional to velocity, for
either driving or damping based on the sign.
forcing_type: String that specifies forcing. This must specify the name of
function declared in FORCING_FUNCTIONS (taylor_green, etc.)
Returns:
Forcing function.
"""
linear_force = linear_forcing(grid, linear_coefficient)
constant_force_fn = FORCING_FUNCTIONS.get(forcing_type)
if constant_force_fn is None:
raise ValueError('Unknown `forcing_type`. '
f'Expected one of {list(FORCING_FUNCTIONS.keys())}; '
f'got {forcing_type}.')
constant_force = constant_force_fn(grid, constant_magnitude,
constant_wavenumber)
return sum_forcings(linear_force, constant_force)
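# Example usage (illustrative sketch): a driven-and-damped 2D configuration as
# described in the docstring above (constant_magnitude > 0 injects energy,
# linear_coefficient < 0 damps the largest scales). The numeric values are
# arbitrary.
#
#   grid = grids.Grid((256, 256), domain=((0, 2 * jnp.pi),) * 2)
#   forcing_fn = simple_turbulence_forcing(
#       grid, constant_magnitude=1.0, constant_wavenumber=4,
#       linear_coefficient=-0.1, forcing_type='kolmogorov')
#   # forcing_fn(v) returns one force GridArray per velocity component.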
def filtered_forcing(
spectral_density: Callable[[Array], Array],
grid: grids.Grid,
) -> ForcingFn:
"""Apply forcing as a function of angular frequency.
Args:
spectral_density: if `x_hat` is a Fourier component of the velocity with
angular frequency `k` then the forcing applied to `x_hat` is
`spectral_density(k)`.
grid: object representing spatial discretization.
Returns:
A forcing function that applies filtered forcing.
"""
def forcing(v):
filter_ = grids.applied(
functools.partial(filter_utils.filter, spectral_density, grid=grid))
return tuple(filter_(u.array) for u in v)
return forcing
def filtered_linear_forcing(
lower_wavenumber: float,
upper_wavenumber: float,
coefficient: float,
grid: grids.Grid,
) -> ForcingFn:
"""Apply linear forcing to low frequency components of the velocity field.
Args:
lower_wavenumber: the minimum wavenumber to which forcing should be
applied.
upper_wavenumber: the maximum wavenumber to which forcing should be
applied.
coefficient: the linear coefficient for forcing applied to components with
wavenumber below `threshold`.
grid: object representing spatial discretization.
Returns:
A forcing function that applies filtered linear forcing.
"""
def spectral_density(k):
return jnp.where(((k >= 2 * jnp.pi * lower_wavenumber) &
(k <= 2 * jnp.pi * upper_wavenumber)),
coefficient,
0)
return filtered_forcing(spectral_density, grid)
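# Example usage (illustrative sketch): damp only the large-scale part of the
# velocity field. With the values below, Fourier components whose ordinal
# wavenumber lies in [0, 2] receive a force of -0.1 times the component; all
# other components receive zero force.
#
#   grid = grids.Grid((128, 128), domain=((0, 2 * jnp.pi),) * 2)
#   damping_fn = filtered_linear_forcing(
#       lower_wavenumber=0, upper_wavenumber=2, coefficient=-0.1, grid=grid)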
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.forcings."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
def _make_zero_velocity_field(grid):
ndim = grid.ndim
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
bc = boundaries.periodic_boundary_conditions(grid.ndim)
return tuple(
grids.GridVariable(
grids.GridArray(jnp.zeros(grid.shape), tuple(offset), grid), bc)
for ax, offset in enumerate(offsets))
class ForcingsTest(test_util.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='_taylor_green_forcing',
partial_force_fn=functools.partial(
forcings.taylor_green_forcing, scale=1.0, k=2),
),
dict(
testcase_name='_kolmogorov_forcing',
partial_force_fn=functools.partial(
forcings.kolmogorov_forcing, scale=1.0, k=2),
),
dict(
testcase_name='_linear_forcing',
partial_force_fn=functools.partial(
forcings.linear_forcing, coefficient=2.0),
),
dict(
testcase_name='_no_forcing',
partial_force_fn=functools.partial(forcings.no_forcing)),
dict(
testcase_name='_simple_turbulence_forcing',
partial_force_fn=functools.partial(
forcings.simple_turbulence_forcing,
constant_magnitude=0.0,
constant_wavenumber=2,
linear_coefficient=0.0,
forcing_type='kolmogorov'),
),
)
def test_forcing_function(self, partial_force_fn):
for ndim in [2, 3]:
with self.subTest(f'ndim={ndim}'):
grid = grids.Grid((16,) * ndim)
v = _make_zero_velocity_field(grid)
force_fn = partial_force_fn(grid)
force = force_fn(v)
# Check that offset and grid match velocity input
for d in range(ndim):
self.assertAllClose(0 * force[d], v[d].array)
def test_sum_forcings(self):
grid = grids.Grid((16, 16))
force_fn_1 = forcings.kolmogorov_forcing(grid, scale=1.0, k=2)
force_fn_2 = forcings.no_forcing(grid)
force_fn_sum = forcings.sum_forcings(force_fn_1, force_fn_2)
v = _make_zero_velocity_field(grid)
force_1 = force_fn_1(v)
force_sum = force_fn_sum(v)
for d in range(grid.ndim):
self.assertArrayEqual(force_sum[d], force_1[d])
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
dict(testcase_name='_low_frequency_unchanged_2D',
ndim=2,
grid_size=16,
lower_wavenumber=0,
upper_wavenumber=2,
coefficient=1,
# velocity is concentrated on wavenumber < sqrt(2)
# so we expect it to pass through the filter.
velocity_function=lambda x, y: (jnp.cos(2 * jnp.pi * x),
jnp.cos(2 * jnp.pi * y)),
expected_force_function=lambda x, y: (jnp.cos(2 * jnp.pi * x),
jnp.cos(2 * jnp.pi * y))),
dict(testcase_name='_high_frequency_zeros_3D',
ndim=3,
grid_size=16,
lower_wavenumber=0,
upper_wavenumber=1,
coefficient=1,
# velocity is concentrated on wave numbers 2 to 2 * sqrt(3)
# so we expect it to be filtered entirely.
velocity_function=lambda x, y, z: (jnp.cos(4 * jnp.pi * x),
jnp.cos(4 * jnp.pi * y),
jnp.cos(4 * jnp.pi * z),),
expected_force_function=lambda x, y, z: (jnp.zeros_like(x),
jnp.zeros_like(y),
jnp.zeros_like(z))),
)
def test_filtered_linear_forcing(self,
ndim,
grid_size,
lower_wavenumber,
upper_wavenumber,
coefficient,
velocity_function,
expected_force_function):
grid = grids.Grid((grid_size,) * ndim,
domain=((0, 1),) * ndim)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
velocity = tuple(
grids.GridVariable(grids.GridArray(u, offset, grid), bc)
for u, offset in zip(velocity_function(*grid.mesh()), grid.cell_faces))
expected_force = expected_force_function(*grid.mesh())
actual_force = forcings.filtered_linear_forcing(
lower_wavenumber, upper_wavenumber, coefficient, grid)(velocity)
for expected, actual in zip(expected_force, actual_force):
self.assertAllClose(expected, actual.data, atol=1e-5)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX utility functions for JAX-CFD."""
import contextlib
from typing import Any, Callable, Sequence
import jax
from jax import tree_util
import jax.numpy as jnp
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# Not accurate for contextmanager
# pylint: disable=g-doc-return-or-yield
# There is currently no good way to indicate a jax "pytree" with arrays at its
# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any
_INITIALIZING = 0
@contextlib.contextmanager
def init_context():
"""Creates a context in which scan() only evaluates f() once.
This is useful for initializing a neural net with Haiku that involves modules
that are applied inside scan(). Within init_context(), these modules are only
called once. This allows us to preserve the pre-omnistaging behavior of JAX,
e.g., so we can initialize a neural net module passed directly into a scanned
function.
"""
global _INITIALIZING
_INITIALIZING += 1
try:
yield
finally:
_INITIALIZING -= 1
def _tree_stack(trees: Sequence[PyTree]) -> PyTree:
if trees:
return tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)
else:
return trees
def scan(f, init, xs, length=None):
"""A version of jax.lax.scan that supports init_context()."""
# Note: we use our own version of scan rather than haiku.scan() because
# haiku.scan() only support use inside haiku modules, but we want to be able
# to use the same scan function even when not using haiku.
if _INITIALIZING:
xs_flat, treedef = tree_util.tree_flatten(xs)
if length is None:
length, = {x.shape[0] for x in xs_flat}
x0 = tree_util.tree_unflatten(treedef, [x[0, ...] for x in xs_flat])
carry, y0 = f(init, x0)
# Create a dummy-output of the right shape while only calling f() once.
ys = _tree_stack(length * [y0])
return carry, ys
return jax.lax.scan(f, init, xs, length)
def repeated(f: Callable, steps: int) -> Callable:
"""Returns a repeatedly applied version of f()."""
def f_repeated(x_initial):
g = lambda x, _: (f(x), None)
x_final, _ = scan(g, x_initial, xs=None, length=steps)
return x_final
return f_repeated
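# Example (illustrative, `step_fn` and `initial_state` are placeholders):
# repeated(step_fn, steps=10) composes step_fn ten times, equivalent to calling
# it in a Python loop, but implemented with scan() so the body is traced only
# once under jit.
#
#   advance_10 = repeated(step_fn, steps=10)
#   final_state = advance_10(initial_state)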
def _identity(x):
return x
def trajectory(
step_fn: Callable,
steps: int,
post_process: Callable = _identity,
*,
start_with_input: bool = False,
) -> Callable:
"""Returns a function that accumulates repeated applications of `step_fn`.
Args:
step_fn: function that takes a state and returns state after one time step.
steps: number of steps to take when generating the trajectory.
post_process: transformation to be applied to each frame of the trajectory.
start_with_input: if True, output the trajectory at steps [0, ..., steps-1]
instead of steps [1, ..., steps].
Returns:
A function that takes an initial state and returns a tuple consisting of:
(1) the final frame of the trajectory _before_ `post_process` is applied.
(2) trajectory of length `steps` representing time evolution.
"""
# TODO(shoyer): change the default to start_with_input=True, once we're sure
# it works for training.
def step(carry_in, _):
carry_out = step_fn(carry_in)
frame = post_process(carry_in if start_with_input else carry_out)
return carry_out, frame
def multistep(values):
return scan(step, values, xs=None, length=steps)
return multistep
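# Example usage (illustrative sketch, names are placeholders): roll out 100
# steps of a step function and keep only the first component of each frame.
#
#   rollout_fn = trajectory(step_fn, steps=100, post_process=lambda x: x[0])
#   final_state, frames = rollout_fn(initial_state)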
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.funcutils."""
from absl.testing import absltest
from absl.testing import parameterized
from jax_cfd.base import funcutils
from jax_cfd.base import test_util
import numpy as np
class TrajectoryTests(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='identity',
trajectory_length=6,
post_process=funcutils._identity),
dict(testcase_name='squared_postprocessing',
trajectory_length=12,
post_process=lambda x: (x[0] ** 2,)),
)
def test_trajectory(self, trajectory_length, post_process):
def step_fn(x):
return (x[0] + 1,)
trajectory_fn = funcutils.trajectory(
step_fn, trajectory_length, post_process)
initial_state = (2 * np.ones(1),)
expected_frames = []
frame = initial_state
for _ in range(trajectory_length):
frame = step_fn(frame)
expected_frames.append(post_process(frame))
expected_output = (np.stack([x[0] for x in expected_frames]),)
_, actual_output = trajectory_fn(initial_state)
for expected, actual in zip(expected_output, actual_output):
self.assertAllClose(expected, actual, atol=1e-9)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grid classes that contain discretization information and boundary conditions."""
from __future__ import annotations
import dataclasses
import numbers
import operator
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import jax
from jax import core
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
import numpy as np
# TODO(jamieas): consider moving common types to a separate module.
# TODO(shoyer): consider adding jnp.ndarray?
Array = Union[np.ndarray, jax.Array]
IntOrSequence = Union[int, Sequence[int]]
# There is currently no good way to indicate a jax "pytree" with arrays at its
# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any
@register_pytree_node_class
@dataclasses.dataclass
class GridArray(np.lib.mixins.NDArrayOperatorsMixin):
"""Data with an alignment offset and an associated grid.
Offset values in the range [0, 1] fall within a single grid cell.
Examples:
offset=(0, 0) means that each point is at the bottom-left corner.
offset=(0.5, 0.5) is at the grid center.
offset=(1, 0.5) is centered on the right-side edge.
Attributes:
data: array values.
offset: alignment location of the data with respect to the grid.
grid: the Grid associated with the array data.
dtype: type of the array data.
shape: lengths of the array dimensions.
"""
# Don't (yet) enforce any explicit consistency requirements between data.ndim
# and len(offset), e.g., so we can feel free to add extra time/batch/channel
# dimensions. But in most cases they should probably match.
# Also don't enforce explicit consistency between data.shape and grid.shape,
# but similarly they should probably match.
data: Array
offset: Tuple[float, ...]
grid: Grid
def tree_flatten(self):
"""Returns flattening recipe for GridArray JAX pytree."""
children = (self.data,)
aux_data = (self.offset, self.grid)
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Returns unflattening recipe for GridArray JAX pytree."""
return cls(*children, *aux_data)
@property
def dtype(self):
return self.data.dtype
@property
def shape(self) -> Tuple[int, ...]:
return self.data.shape
_HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, core.ShapedArray)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
"""Define arithmetic on GridArrays using NumPy's mixin."""
for x in inputs:
if not isinstance(x, self._HANDLED_TYPES + (GridArray,)):
return NotImplemented
if method != '__call__':
return NotImplemented
try:
# get the corresponding jax.numpy function for the NumPy ufunc
func = getattr(jnp, ufunc.__name__)
except AttributeError:
return NotImplemented
arrays = [x.data if isinstance(x, GridArray) else x for x in inputs]
result = func(*arrays)
offset = consistent_offset(*[x for x in inputs if isinstance(x, GridArray)])
grid = consistent_grid(*[x for x in inputs if isinstance(x, GridArray)])
if isinstance(result, tuple):
return tuple(GridArray(r, offset, grid) for r in result)
else:
return GridArray(result, offset, grid)
GridArrayVector = Tuple[GridArray, ...]
class GridArrayTensor(np.ndarray):
"""A numpy array of GridArrays, representing a physical tensor field.
Packing tensor coordinates into a numpy array of dtype object is useful
because pointwise matrix operations like trace, transpose, and matrix
multiplication of physical tensor quantities are meaningful.
Example usage:
grad = fd.gradient_tensor(uv) # a rank 2 Tensor
strain_rate = (grad + grad.T) / 2.
nu_smag = np.sqrt(np.trace(strain_rate.dot(strain_rate)))
nu_smag = Tensor(nu_smag) # a rank 0 Tensor
subgrid_stress = -2 * nu_smag * strain_rate # a rank 2 Tensor
"""
def __new__(cls, arrays):
return np.asarray(arrays).view(cls)
jax.tree_util.register_pytree_node(
GridArrayTensor,
lambda tensor: (tensor.ravel().tolist(), tensor.shape),
lambda shape, arrays: GridArrayTensor(np.asarray(arrays).reshape(shape)),
)
@dataclasses.dataclass(init=False, frozen=True)
class BoundaryConditions:
"""Base class for boundary conditions on a PDE variable.
Attributes:
types: `types[i]` is a tuple specifying the lower and upper BC types for
dimension `i`.
"""
types: Tuple[Tuple[str, str], ...]
def shift(
self,
u: GridArray,
offset: int,
axis: int,
mode: Optional[str] = 'extend',
) -> GridArray:
"""Shift an GridArray by `offset`.
Args:
u: a `GridArray` object.
offset: positive or negative integer offset to shift.
axis: axis to shift along.
mode: specifies how to extend past the boundary/ghost cells.
Valid options contained in boundaries.Padding.
Returns:
A copy of `u`, shifted by `offset`. The returned `GridArray` has offset
`u.offset + offset`.
"""
raise NotImplementedError(
'shift() not implemented in BoundaryConditions base class.')
def values(self, axis: int, grid: Grid, offset: Optional[Tuple[float, ...]],
time: Optional[float]) -> Tuple[Optional[Array], Optional[Array]]:
"""Returns Arrays specifying boundary values on the grid along axis.
Args:
axis: axis along which to return boundary values.
grid: a `Grid` object on which to evaluate boundary conditions.
offset: a Tuple of offsets that specifies (along with grid) where to
evaluate boundary conditions in space.
time: a float used as an input to boundary function.
Returns:
A tuple of arrays of grid.ndim - 1 dimensions that specify values on the
boundary. In the case of periodic boundaries, returns a tuple (None, None).
"""
raise NotImplementedError(
'values() not implemented in BoundaryConditions base class.')
def pad(
self,
u: GridArray,
width: int,
axis: int,
mode: Optional[str] = 'extend',
) -> GridArray:
"""Returns Arrays padded according to boundary condition.
Args:
u: a `GridArray` object.
width: number of elements to pad along axis. Use negative value for lower
boundary or positive value for upper boundary.
axis: axis to pad along.
mode: specifies how to extend past the boundary/ghost cells.
Valid options contained in boundaries.Padding.
Returns:
A GridArray that is elongated along axis with padded values.
"""
raise NotImplementedError(
'pad() not implemented in BoundaryConditions base class.')
def trim_boundary(self, u: GridArray) -> GridArray:
"""Returns GridArray without the grid points on the boundary.
Some grid points of GridArray might coincide with boundary. This trims those
values.
Args:
u: a `GridArray` object.
Returns:
A GridArray shrunk along certain dimensions.
"""
raise NotImplementedError(
'trim_boundary() not implemented in BoundaryConditions base class.')
def pad_and_impose_bc(
self,
u: GridArray,
offset_to_pad_to: Optional[Tuple[float, ...]] = None) -> GridVariable:
"""Returns GridVariable with correct boundary condition.
Some grid points of GridArray might coincide with boundary. This ensures
that the GridVariable.array agrees with GridVariable.bc.
Args:
u: a `GridArray` object that specifies only scalar values on the internal
nodes.
offset_to_pad_to: a Tuple specifying the desired offset to pad to. Note that
in the Dirichlet case, if the function is given just the interior array, it
can pad to either offset 0 or offset 1.
Returns:
A GridVariable that has correct boundary.
"""
raise NotImplementedError(
'pad_and_impose_bc() not implemented in BoundaryConditions base class.')
def impose_bc(self, u: GridArray) -> GridVariable:
"""Returns GridVariable with correct boundary condition.
Some grid points of GridArray might coincide with boundary. This ensures
that the GridVariable.array agrees with GridVariable.bc.
Args:
u: a `GridArray` object.
Returns:
A GridVariable that has correct boundary.
"""
raise NotImplementedError(
'impose_bc() not implemented in BoundaryConditions base class.')
@register_pytree_node_class
@dataclasses.dataclass
class GridVariable:
"""Associates a GridArray with BoundaryConditions.
Performing pad and shift operations, e.g. for finite difference calculations,
requires boundary condition (BC) information. Since different variables in a
PDE system can have different BCs, this class associates a specific variable's
data with its BCs.
Array operations on GridVariables act like array operations on the
encapsulated GridArray.
Attributes:
array: GridArray with the array data, offset, and associated grid.
bc: boundary conditions for this variable.
grid: the Grid associated with the array data.
dtype: type of the array data.
shape: lengths of the array dimensions.
data: array values.
offset: alignment location of the data with respect to the grid.
"""
array: GridArray
bc: BoundaryConditions
def __post_init__(self):
if not isinstance(self.array, GridArray): # frequently missed by pytype
raise ValueError(
f'Expected array type to be GridArray, got {type(self.array)}')
if len(self.bc.types) != self.grid.ndim:
raise ValueError(
'Incompatible dimension between grid and bc, grid dimension = '
f'{self.grid.ndim}, bc dimension = {len(self.bc.types)}')
def tree_flatten(self):
"""Returns flattening recipe for GridVariable JAX pytree."""
children = (self.array,)
aux_data = (self.bc,)
return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Returns unflattening recipe for GridVariable JAX pytree."""
return cls(*children, *aux_data)
@property
def dtype(self):
return self.array.dtype
@property
def shape(self) -> Tuple[int, ...]:
return self.array.shape
@property
def data(self) -> Array:
return self.array.data
@property
def offset(self) -> Tuple[float, ...]:
return self.array.offset
@property
def grid(self) -> Grid:
return self.array.grid
def shift(
self,
offset: int,
axis: int,
mode: Optional[str] = 'extend',
) -> GridArray:
"""Shift this GridVariable by `offset`.
Args:
offset: positive or negative integer offset to shift.
axis: axis to shift along.
mode: specifies how to extend past the boundary/ghost cells.
Valid options contained in boundaries.Padding.
Returns:
A copy of the encapsulated GridArray, shifted by `offset`. The returned
GridArray has offset `u.offset + offset`.
"""
return self.bc.shift(self.array, offset, axis, mode)
def _interior_grid(self) -> Grid:
"""Returns only the interior grid points."""
grid = self.array.grid
domain = list(grid.domain)
shape = list(grid.shape)
for axis in range(self.grid.ndim):
# nothing happens in periodic case
if self.bc.types[axis][1] == 'periodic':
continue
# nothing happens if the offset is not 0.0 or 1.0
# this will automatically set the grid to interior.
if np.isclose(self.array.offset[axis], 1.0):
shape[axis] -= 1
domain[axis] = (domain[axis][0], domain[axis][1] - grid.step[axis])
elif np.isclose(self.array.offset[axis], 0.0):
shape[axis] -= 1
domain[axis] = (domain[axis][0] + grid.step[axis], domain[axis][1])
return Grid(shape, domain=tuple(domain))
def trim_boundary(self) -> GridArray:
"""Returns a GridArray associated only with interior points.
Interior is defined as the following:
for d in range(u.grid.ndim):
points = u.grid.axes(offset=u.offset[d])
interior_points =
all points where grid.domain[d][0] < points < grid.domain[d][1]
The exception is when the boundary conditions are periodic,
in which case all points are included in the interior.
In case of dirichlet with edge offset, the grid and array size is reduced,
since one scalar lies exactly on the boundary. In all other cases,
self.grid and self.array are returned.
"""
return self.bc.trim_boundary(self.array)
def impose_bc(self) -> GridVariable:
"""Returns the GridVariable with edge BC enforced, if applicable.
For GridVariables having nonperiodic BC and offset 0 or 1, there are values
in the array data that are dependent on the boundary condition.
impose_bc() changes these boundary values to match the prescribed BC.
"""
return self.bc.impose_bc(self.array)
GridVariableVector = Tuple[GridVariable, ...]
def applied(func):
"""Convert an array function into one defined on GridArrays.
Since `func` can only act on `data` attribute of GridArray, it implicitly
enforces that `func` cannot modify the other attributes such as offset.
Args:
func: function being wrapped.
Returns:
A wrapped version of `func` that takes GridArray instead of Array args.
"""
def wrapper(*args, **kwargs): # pylint: disable=missing-docstring
for arg in args + tuple(kwargs.values()):
if isinstance(arg, GridVariable):
raise ValueError('grids.applied() cannot be used with GridVariable')
offset = consistent_offset(*[
arg for arg in args + tuple(kwargs.values())
if isinstance(arg, GridArray)
])
grid = consistent_grid(*[
arg for arg in args + tuple(kwargs.values())
if isinstance(arg, GridArray)
])
raw_args = [arg.data if isinstance(arg, GridArray) else arg for arg in args]
raw_kwargs = {
k: v.data if isinstance(v, GridArray) else v for k, v in kwargs.items()
}
data = func(*raw_args, **raw_kwargs)
return GridArray(data, offset, grid)
return wrapper
# Aliases for often used `grids.applied` functions.
where = applied(jnp.where)
def averaged_offset(
*arrays: Union[GridArray, GridVariable]) -> Tuple[float, ...]:
"""Returns the averaged offset of the given arrays."""
offset = np.mean([array.offset for array in arrays], axis=0)
return tuple(offset.tolist())
def control_volume_offsets(
c: Union[GridArray, GridVariable]) -> Tuple[Tuple[float, ...], ...]:
"""Returns offsets for the faces of the control volume centered at `c`."""
return tuple(
tuple(o + .5 if i == j else o
for i, o in enumerate(c.offset))
for j in range(len(c.offset)))
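# Worked example: for a cell-centered quantity with c.offset == (0.5, 0.5),
# control_volume_offsets(c) returns ((1.0, 0.5), (0.5, 1.0)), i.e. the offsets
# of the 'forward' faces of the control volume along each axis.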
class InconsistentOffsetError(Exception):
"""Raised for cases of inconsistent offset in GridArrays."""
def consistent_offset(
*arrays: Union[GridArray, GridVariable]) -> Tuple[float, ...]:
"""Returns the unique offset, or raises InconsistentOffsetError."""
offsets = {array.offset for array in arrays}
if len(offsets) != 1:
raise InconsistentOffsetError(
f'arrays do not have a unique offset: {offsets}')
offset, = offsets
return offset
class InconsistentGridError(Exception):
"""Raised for cases of inconsistent grids between GridArrays."""
def consistent_grid(*arrays: Union[GridArray, GridVariable]) -> Grid:
"""Returns the unique grid, or raises InconsistentGridError."""
grids = {array.grid for array in arrays}
if len(grids) != 1:
raise InconsistentGridError(f'arrays do not have a unique grid: {grids}')
grid, = grids
return grid
class InconsistentBoundaryConditionsError(Exception):
"""Raised for cases of inconsistent bc between GridVariables."""
def unique_boundary_conditions(*arrays: GridVariable) -> BoundaryConditions:
"""Returns the unique BCs, or raises InconsistentBoundaryConditionsError."""
bcs = {array.bc for array in arrays}
if len(bcs) != 1:
raise InconsistentBoundaryConditionsError(
f'arrays do not have a unique bc: {bcs}')
bc, = bcs
return bc
@dataclasses.dataclass(init=False, frozen=True)
class Grid:
"""Describes the size and shape for an Arakawa C-Grid.
See https://en.wikipedia.org/wiki/Arakawa_grids.
This class describes domains that can be written as an outer-product of 1D
grids. Along each dimension `i`:
- `shape[i]` gives the whole number of grid cells on a single device.
- `step[i]` is the width of each grid cell.
- `(lower, upper) = domain[i]` gives the locations of lower and upper
boundaries. The identity `upper - lower = step[i] * shape[i]` is enforced.
"""
shape: Tuple[int, ...]
step: Tuple[float, ...]
domain: Tuple[Tuple[float, float], ...]
def __init__(
self,
shape: Sequence[int],
step: Optional[Union[float, Sequence[float]]] = None,
domain: Optional[Union[float, Sequence[Tuple[float, float]]]] = None,
):
"""Construct a grid object."""
shape = tuple(operator.index(s) for s in shape)
object.__setattr__(self, 'shape', shape)
if step is not None and domain is not None:
raise TypeError('cannot provide both step and domain')
elif domain is not None:
if isinstance(domain, (int, float)):
domain = ((0, domain),) * len(shape)
else:
if len(domain) != self.ndim:
raise ValueError('length of domain does not match ndim: '
f'{len(domain)} != {self.ndim}')
for bounds in domain:
if len(bounds) != 2:
raise ValueError(
f'domain is not sequence of pairs of numbers: {domain}')
domain = tuple((float(lower), float(upper)) for lower, upper in domain)
else:
if step is None:
step = 1
if isinstance(step, numbers.Number):
step = (step,) * self.ndim
elif len(step) != self.ndim:
raise ValueError('length of step does not match ndim: '
f'{len(step)} != {self.ndim}')
domain = tuple(
(0.0, float(step_ * size)) for step_, size in zip(step, shape))
object.__setattr__(self, 'domain', domain)
step = tuple(
(upper - lower) / size for (lower, upper), size in zip(domain, shape))
object.__setattr__(self, 'step', step)
@property
def ndim(self) -> int:
"""Returns the number of dimensions of this grid."""
return len(self.shape)
@property
def cell_center(self) -> Tuple[float, ...]:
"""Offset at the center of each grid cell."""
return self.ndim * (0.5,)
@property
def cell_faces(self) -> Tuple[Tuple[float, ...], ...]:
"""Returns the offsets at each of the 'forward' cell faces."""
d = self.ndim
offsets = (np.eye(d) + np.ones([d, d])) / 2.
return tuple(tuple(float(o) for o in offset) for offset in offsets)
def stagger(self, v: Tuple[Array, ...]) -> Tuple[GridArray, ...]:
"""Places the velocity components of `v` on the `Grid`'s cell faces."""
offsets = self.cell_faces
return tuple(GridArray(u, o, self) for u, o in zip(v, offsets))
def center(self, v: PyTree) -> PyTree:
"""Places all arrays in the pytree `v` at the `Grid`'s cell center."""
offset = self.cell_center
return jax.tree_util.tree_map(lambda u: GridArray(u, offset, self), v)
def axes(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]:
"""Returns a tuple of arrays containing the grid points along each axis.
Args:
offset: an optional sequence of length `ndim`. The grid will be shifted by
`offset * self.step`.
Returns:
A tuple of `self.ndim` arrays. The jth return value has shape
`[self.shape[j]]`.
"""
if offset is None:
offset = self.cell_center
if len(offset) != self.ndim:
raise ValueError(f'unexpected offset length: {len(offset)} vs '
f'{self.ndim}')
return tuple(lower + (jnp.arange(length) + offset_i) * step
for (lower, _), offset_i, length, step in zip(
self.domain, offset, self.shape, self.step))
def fft_axes(self) -> Tuple[Array, ...]:
"""Returns the ordinal frequencies corresponding to the axes.
Transforms each axis into the *ordinal* frequencies for the Fast Fourier
Transform (FFT). Multiply by `2 * jnp.pi` to get angular frequencies.
Returns:
A tuple of `self.ndim` arrays. The jth return value has shape
`[self.shape[j]]`.
"""
freq_axes = tuple(
jnp.fft.fftfreq(n, d=s) for (n, s) in zip(self.shape, self.step))
return freq_axes
def rfft_axes(self) -> Tuple[Array, ...]:
"""Returns the ordinal frequencies corresponding to the axes.
Transforms each axis into the *ordinal* frequencies for the Fast Fourier
Transform (FFT). Most useful for doing computations for real-valued (not
complex valued) signals.
Multiply by `2 * jnp.pi` to get angular frequencies.
Returns:
A tuple of `self.ndim` arrays. The shape of each array matches the result
of `rfftfreq`/`fftfreq`. Specifically, `rfft` is applied to the last
dimension, resulting in an array of length `self.shape[-1] // 2 + 1`.
Complex `fft` is applied to the other dimensions, resulting in arrays of
length `self.shape[j]`.
"""
fft_axes = tuple(
jnp.fft.fftfreq(n, d=s)
for (n, s) in zip(self.shape[:-1], self.step[:-1]))
rfft_axis = (jnp.fft.rfftfreq(self.shape[-1], d=self.step[-1]),)
return fft_axes + rfft_axis
def mesh(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]:
"""Returns an tuple of arrays containing positions in each grid cell.
Args:
offset: an optional sequence of length `ndim`. The grid will be shifted by
`offset * self.step`.
Returns:
A tuple of `self.ndim` arrays, each of shape `self.shape`. In 3
dimensions, entry `self.mesh[n][i, j, k]` is the location of point
`i, j, k` in dimension `n`.
"""
axes = self.axes(offset)
return tuple(jnp.meshgrid(*axes, indexing='ij'))
def rfft_mesh(self) -> Tuple[Array, ...]:
"""Returns a tuple of arrays containing positions in rfft space."""
rfft_axes = self.rfft_axes()
return tuple(jnp.meshgrid(*rfft_axes, indexing='ij'))
def eval_on_mesh(self,
fn: Callable[..., Array],
offset: Optional[Sequence[float]] = None) -> GridArray:
"""Evaluates the function on the grid mesh with the specified offset.
Args:
fn: A function that accepts the mesh arrays and returns an array.
offset: an optional sequence of length `ndim`. If not specified, uses the
offset for the cell center.
Returns:
fn(x, y, ...) evaluated on the mesh, as a GridArray with specified offset.
"""
if offset is None:
offset = self.cell_center
return GridArray(fn(*self.mesh(offset)), offset, self)
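# Illustrative sketch; `_example_grid_geometry` is a hypothetical helper, not
# part of the library API. It spells out the shape/step/domain relationships
# documented above on a small 1D grid.
def _example_grid_geometry():
  grid = Grid((4,), domain=((0., 2.),))
  assert grid.step == (0.5,)         # step = (upper - lower) / shape
  assert grid.cell_center == (0.5,)  # offsets are in units of the step
  x, = grid.axes()                   # cell-center coordinates
  np.testing.assert_allclose(x, [0.25, 0.75, 1.25, 1.75])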
def domain_interior_masks(grid: Grid):
"""Returns cell face arrays with 1 on the interior, 0 on the boundary."""
masks = []
for offset in grid.cell_faces:
mesh = grid.mesh(offset)
mask = 1
for i, x in enumerate(mesh):
lower = (np.invert(np.isclose(x, grid.domain[i][0]))).astype('int')
upper = (np.invert(np.isclose(x, grid.domain[i][1]))).astype('int')
mask = mask * upper * lower
masks.append(mask)
return tuple(masks)
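# Illustrative sketch; `_example_domain_interior_masks` is a hypothetical
# helper, not part of the library API. On a 1D grid the forward-face mask is
# zero only at the upper boundary point.
def _example_domain_interior_masks():
  mask, = domain_interior_masks(Grid((5,)))
  np.testing.assert_array_equal(mask, [1, 1, 1, 1, 0])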
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.grids."""
# TODO(jamieas): Consider updating these tests using the `hypothesis` framework.
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np
class GridArrayTest(test_util.TestCase):
def test_tree_util(self):
array = grids.GridArray(jnp.arange(3), offset=(0,), grid=grids.Grid((3,)))
flat, treedef = jax.tree_util.tree_flatten(array)
roundtripped = jax.tree_util.tree_unflatten(treedef, flat)
self.assertArrayEqual(array, roundtripped)
def test_consistent_offset(self):
data = jnp.arange(3)
grid = grids.Grid((3,))
array_offset_0 = grids.GridArray(data, offset=(0,), grid=grid)
array_offset_1 = grids.GridArray(data, offset=(1,), grid=grid)
offset = grids.consistent_offset(array_offset_0, array_offset_0)
self.assertEqual(offset, (0,))
with self.assertRaises(grids.InconsistentOffsetError):
grids.consistent_offset(array_offset_0, array_offset_1)
def test_averaged_offset(self):
data = jnp.arange(3)
grid = grids.Grid((3,))
array_offset_0 = grids.GridArray(data, offset=(0,), grid=grid)
array_offset_1 = grids.GridArray(data, offset=(1,), grid=grid)
averaged_offset = grids.averaged_offset(array_offset_0, array_offset_1)
self.assertEqual(averaged_offset, (0.5,))
def test_control_volume_offsets(self):
data = jnp.zeros((5, 5))
grid = grids.Grid((5, 5))
array = grids.GridArray(data, offset=(0, 0), grid=grid)
cv_offset = grids.control_volume_offsets(array)
self.assertEqual(cv_offset, ((0.5, 0), (0, 0.5)))
def test_consistent_grid(self):
data = jnp.arange(3)
offset = (0,)
array_grid_3 = grids.GridArray(data, offset, grid=grids.Grid((3,)))
array_grid_5 = grids.GridArray(data, offset, grid=grids.Grid((5,)))
grid = grids.consistent_grid(array_grid_3, array_grid_3)
self.assertEqual(grid, grids.Grid((3,)))
with self.assertRaises(grids.InconsistentGridError):
grids.consistent_grid(array_grid_3, array_grid_5)
def test_add_sub_correctness(self):
values_1 = np.random.uniform(size=(5, 5))
values_2 = np.random.uniform(size=(5, 5))
offsets = (0.5, 0.5)
grid = grids.Grid((5, 5))
input_array_1 = grids.GridArray(values_1, offsets, grid)
input_array_2 = grids.GridArray(values_2, offsets, grid)
actual_sum = input_array_1 + input_array_2
actual_sub = input_array_1 - input_array_2
expected_sum = grids.GridArray(values_1 + values_2, offsets, grid)
expected_sub = grids.GridArray(values_1 - values_2, offsets, grid)
self.assertAllClose(actual_sum, expected_sum, atol=1e-7)
self.assertAllClose(actual_sub, expected_sub, atol=1e-7)
def test_add_sub_offset_raise(self):
values_1 = np.random.uniform(size=(5, 5))
values_2 = np.random.uniform(size=(5, 5))
offset_1 = (0.5, 0.5)
offset_2 = (0.5, 0.0)
grid = grids.Grid((5, 5))
input_array_1 = grids.GridArray(values_1, offset_1, grid)
input_array_2 = grids.GridArray(values_2, offset_2, grid)
with self.assertRaises(grids.InconsistentOffsetError):
_ = input_array_1 + input_array_2
with self.assertRaises(grids.InconsistentOffsetError):
_ = input_array_1 - input_array_2
def test_add_sub_grid_raise(self):
values_1 = np.random.uniform(size=(5, 5))
values_2 = np.random.uniform(size=(5, 5))
offset = (0.5, 0.5)
grid_1 = grids.Grid((5, 5), domain=((0, 1), (0, 1)))
grid_2 = grids.Grid((5, 5), domain=((-2, 2), (-2, 2)))
input_array_1 = grids.GridArray(values_1, offset, grid_1)
input_array_2 = grids.GridArray(values_2, offset, grid_2)
with self.assertRaises(grids.InconsistentGridError):
_ = input_array_1 + input_array_2
with self.assertRaises(grids.InconsistentGridError):
_ = input_array_1 - input_array_2
def test_mul_div_correctness(self):
values_1 = np.random.uniform(size=(5, 5))
values_2 = np.random.uniform(size=(5, 5))
scalar = 3.1415
offset = (0.5, 0.5)
grid = grids.Grid((5, 5))
input_array_1 = grids.GridArray(values_1, offset, grid)
input_array_2 = grids.GridArray(values_2, offset, grid)
actual_mul = input_array_1 * input_array_2
array_1_times_scalar = input_array_1 * scalar
expected_1_times_scalar = grids.GridArray(values_1 * scalar, offset, grid)
actual_div = input_array_1 / 2.5
expected_div = grids.GridArray(values_1 / 2.5, offset, grid)
expected_mul = grids.GridArray(values_1 * values_2, offset, grid)
self.assertAllClose(actual_mul, expected_mul, atol=1e-7)
self.assertAllClose(
array_1_times_scalar, expected_1_times_scalar, atol=1e-7)
self.assertAllClose(actual_div, expected_div, atol=1e-7)
def test_add_inplace(self):
values_1 = np.random.uniform(size=(5, 5))
values_2 = np.random.uniform(size=(5, 5))
offsets = (0.5, 0.5)
grid = grids.Grid((5, 5))
array = grids.GridArray(values_1, offsets, grid)
array += values_2
expected = grids.GridArray(values_1 + values_2, offsets, grid)
self.assertAllClose(array, expected, atol=1e-7)
def test_jit(self):
u = grids.GridArray(jnp.ones([10, 10]), (.5, .5), grids.Grid((10, 10)))
def f(u):
return u.data < 2.
self.assertAllClose(f(u), jax.jit(f)(u))
def test_applied(self):
grid = grids.Grid((10, 10))
offset = (0.5, 0.5)
u = grids.GridArray(jnp.ones([10, 10]), offset, grid)
expected = grids.GridArray(-jnp.ones([10, 10]), offset, grid)
actual = grids.applied(jnp.negative)(u)
self.assertAllClose(expected, actual)
class GridVariableTest(test_util.TestCase):
def test_constructor_and_attributes(self):
with self.subTest('1d'):
grid = grids.Grid((10,))
data = np.zeros((10,), dtype=np.float32)
array = grids.GridArray(data, offset=(0.5,), grid=grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
variable = grids.GridVariable(array, bc)
self.assertArrayEqual(variable.array, array)
self.assertEqual(variable.bc, bc)
self.assertEqual(variable.dtype, np.float32)
self.assertEqual(variable.shape, (10,))
self.assertArrayEqual(variable.data, data)
self.assertEqual(variable.offset, (0.5,))
self.assertEqual(variable.grid, grid)
with self.subTest('2d'):
grid = grids.Grid((10, 10))
data = np.zeros((10, 10), dtype=np.float32)
array = grids.GridArray(data, offset=(0.5, 0.5), grid=grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
variable = grids.GridVariable(array, bc)
self.assertArrayEqual(variable.array, array)
self.assertEqual(variable.bc, bc)
self.assertEqual(variable.dtype, np.float32)
self.assertEqual(variable.shape, (10, 10))
self.assertArrayEqual(variable.data, data)
self.assertEqual(variable.offset, (0.5, 0.5))
self.assertEqual(variable.grid, grid)
with self.subTest('batch dim data'):
grid = grids.Grid((10, 10))
data = np.zeros((5, 10, 10), dtype=np.float32)
array = grids.GridArray(data, offset=(0.5, 0.5), grid=grid)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
variable = grids.GridVariable(array, bc)
self.assertArrayEqual(variable.array, array)
self.assertEqual(variable.bc, bc)
self.assertEqual(variable.dtype, np.float32)
self.assertEqual(variable.shape, (5, 10, 10))
self.assertArrayEqual(variable.data, data)
self.assertEqual(variable.offset, (0.5, 0.5))
self.assertEqual(variable.grid, grid)
with self.subTest('raises exception'):
with self.assertRaisesRegex(ValueError,
'Incompatible dimension between grid and bc'):
grid = grids.Grid((10,))
data = np.zeros((10,))
array = grids.GridArray(data, offset=(0.5,), grid=grid) # 1D
bc = boundaries.periodic_boundary_conditions(ndim=2) # 2D
grids.GridVariable(array, bc)
@parameterized.parameters(
dict(
shape=(10,),
offset=(0.0,),
),
dict(
shape=(10,),
offset=(0.5,),
),
dict(
shape=(10,),
offset=(1.0,),
),
dict(
shape=(10, 10),
offset=(1.0, 0.0),
),
dict(
shape=(10, 10, 10),
offset=(1.0, 0.0, 0.5),
),
)
def test_interior_consistency_periodic(self, shape, offset):
grid = grids.Grid(shape)
data = np.random.randint(0, 10, shape)
array = grids.GridArray(data, offset=offset, grid=grid)
bc = boundaries.periodic_boundary_conditions(ndim=len(shape))
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
self.assertArrayEqual(u_interior, u.array)
@parameterized.parameters(
dict(
shape=(10,),
bc=boundaries.dirichlet_boundary_conditions(ndim=1),
),
dict(
shape=(10,),
bc=boundaries.neumann_boundary_conditions(ndim=1),
),
dict(
shape=(10, 10),
bc=boundaries.dirichlet_boundary_conditions(ndim=2),
),
dict(
shape=(10, 10),
bc=boundaries.neumann_boundary_conditions(ndim=2),
),
dict(
shape=(10, 10, 10),
bc=boundaries.dirichlet_boundary_conditions(ndim=3),
),
dict(
shape=(10, 10, 10),
bc=boundaries.neumann_boundary_conditions(ndim=3),
),
)
def test_interior_consistency_no_edge_offsets(self, bc, shape):
grid = grids.Grid(shape)
data = np.random.randint(0, 10, shape)
array = grids.GridArray(data, offset=(0.5,) * len(shape), grid=grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
self.assertArrayEqual(u_interior, u.array)
@parameterized.parameters(
dict(
shape=(10,),
bc=boundaries.neumann_boundary_conditions(ndim=1),
offset=(0.5,)),
dict(
shape=(10, 10),
bc=boundaries.neumann_boundary_conditions(ndim=2),
offset=(0.5, 0.5)),
dict(
shape=(10, 10, 10),
bc=boundaries.neumann_boundary_conditions(ndim=3),
offset=(0.5, 0.5, 0.5)),
)
def test_interior_consistency_neumann(self, shape, bc, offset):
grid = grids.Grid(shape)
data = np.random.randint(0, 10, shape)
array = grids.GridArray(data, offset=offset, grid=grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
self.assertArrayEqual(u_interior, u.array)
@parameterized.parameters(
dict(
shape=(10,),
bc=boundaries.dirichlet_boundary_conditions(ndim=1),
offset=(0.0,)),
dict(
shape=(10, 10),
bc=boundaries.dirichlet_boundary_conditions(ndim=2),
offset=(0.0, 0.0)),
dict(
shape=(10, 10, 10),
bc=boundaries.dirichlet_boundary_conditions(ndim=3),
offset=(0.0, 0.0, 0.0)),
)
def test_interior_consistency_edge_offsets_dirichlet(self, shape, bc, offset):
grid = grids.Grid(shape)
data = np.random.randint(0, 10, shape)
array = grids.GridArray(data, offset=offset, grid=grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
self.assertEqual(u_interior.offset,
tuple(offset + 1 for offset in u.array.offset))
self.assertEqual(u_interior.grid.ndim, u.array.grid.ndim)
self.assertEqual(u_interior.grid.step, u.array.grid.step)
def test_interior_dirichlet(self):
data = np.array([
[11, 12, 13, 14, 15],
[21, 22, 23, 24, 25],
[31, 32, 33, 34, 35],
[41, 42, 43, 44, 45],
])
grid = grids.Grid(shape=(4, 5), domain=((0, 1), (0, 1)))
bc = boundaries.dirichlet_boundary_conditions(ndim=2)
with self.subTest('offset=(1, 0.5)'):
offset = (1., 0.5)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
answer = np.array([[11, 12, 13, 14, 15], [21, 22, 23, 24, 25],
[31, 32, 33, 34, 35]])
self.assertArrayEqual(u_interior.data, answer)
self.assertEqual(u_interior.offset, offset)
self.assertEqual(u.grid, grid)
with self.subTest('offset=(1, 1)'):
offset = (1., 1.)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
answer = np.array([[11, 12, 13, 14], [21, 22, 23, 24], [31, 32, 33, 34]])
self.assertArrayEqual(u_interior.data, answer)
self.assertEqual(u_interior.grid, grid)
with self.subTest('offset=(0.0, 0.5)'):
offset = (0., 0.5)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
answer = np.array([[21, 22, 23, 24, 25], [31, 32, 33, 34, 35],
[41, 42, 43, 44, 45]])
self.assertArrayEqual(u_interior.data, answer)
self.assertEqual(u_interior.grid, grid)
with self.subTest('offset=(0.0, 0.0)'):
offset = (0.0, 0.0)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
answer = np.array([[22, 23, 24, 25], [32, 33, 34, 35], [42, 43, 44, 45]])
self.assertArrayEqual(u_interior.data, answer)
self.assertEqual(u_interior.grid, grid)
with self.subTest('offset=(0.5, 0.0)'):
offset = (0.5, 0.0)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
answer = np.array([[12, 13, 14, 15], [22, 23, 24, 25], [32, 33, 34, 35],
[42, 43, 44, 45]])
self.assertArrayEqual(u_interior.data, answer)
self.assertEqual(u_interior.grid, grid)
# this is consistent for all offsets, not just edge and center.
with self.subTest('offset=(0.25, 0.75)'):
offset = (0.25, 0.75)
array = grids.GridArray(data, offset, grid)
u = grids.GridVariable(array, bc)
u_interior = u.trim_boundary()
self.assertArrayEqual(u_interior.data, data)
self.assertEqual(u_interior.grid, grid)
@parameterized.parameters(
dict(
shape=(10,),
bc=boundaries.periodic_boundary_conditions(ndim=1),
padding=(1, 1),
axis=0,
),
dict(
shape=(10, 10),
bc=boundaries.dirichlet_boundary_conditions(ndim=2),
padding=(2, 1),
axis=1,
),
dict(
shape=(10, 10, 10),
bc=boundaries.neumann_boundary_conditions(ndim=3),
padding=(0, 2),
axis=2,
),
)
def test_shift_pad_trim(self, shape, bc, padding, axis):
grid = grids.Grid(shape)
data = np.random.randint(0, 10, shape)
array = grids.GridArray(data, offset=(0.5,) * len(shape), grid=grid)
u = grids.GridVariable(array, bc)
with self.subTest('shift'):
self.assertArrayEqual(
u.shift(offset=1, axis=axis), bc.shift(array, 1, axis))
with self.subTest('raises exception'):
with self.assertRaisesRegex(ValueError,
'Incompatible dimension between grid and bc'):
grid = grids.Grid((10,))
data = np.zeros((10,))
array = grids.GridArray(data, offset=(0.5,), grid=grid) # 1D
bc = boundaries.periodic_boundary_conditions(ndim=2) # 2D
grids.GridVariable(array, bc)
def test_unique_boundary_conditions(self):
grid = grids.Grid((5,))
array = grids.GridArray(np.arange(5), offset=(0.5,), grid=grid)
bc1 = boundaries.periodic_boundary_conditions(grid.ndim)
bc2 = boundaries.dirichlet_boundary_conditions(grid.ndim)
x_bc1 = grids.GridVariable(array, bc1)
y_bc1 = grids.GridVariable(array, bc1)
z_bc2 = grids.GridVariable(array, bc2)
bc = grids.unique_boundary_conditions(x_bc1, y_bc1)
self.assertEqual(bc, bc1)
with self.assertRaises(grids.InconsistentBoundaryConditionsError):
grids.unique_boundary_conditions(x_bc1, y_bc1, z_bc2)
class GridArrayTensorTest(test_util.TestCase):
def test_tensor_transpose(self):
grid = grids.Grid((5, 5))
offset = (0.5, 0.5)
a = grids.GridArray(1 * jnp.ones([5, 5]), offset, grid)
b = grids.GridArray(2 * jnp.ones([5, 5]), offset, grid)
c = grids.GridArray(3 * jnp.ones([5, 5]), offset, grid)
d = grids.GridArray(4 * jnp.ones([5, 5]), offset, grid)
tensor = grids.GridArrayTensor([[a, b], [c, d]])
self.assertIsInstance(tensor, np.ndarray)
transposed_tensor = np.transpose(tensor)
self.assertAllClose(tensor[0, 1], transposed_tensor[1, 0])
class GridTest(test_util.TestCase):
def test_constructor_and_attributes(self):
with self.subTest('1d'):
grid = grids.Grid((10,))
self.assertEqual(grid.shape, (10,))
self.assertEqual(grid.step, (1.0,))
self.assertEqual(grid.domain, ((0, 10.),))
self.assertEqual(grid.ndim, 1)
self.assertEqual(grid.cell_center, (0.5,))
self.assertEqual(grid.cell_faces, ((1.0,),))
with self.subTest('1d domain scalar size'):
grid = grids.Grid((10,), domain=10)
self.assertEqual(grid.domain, ((0.0, 10.0),))
with self.subTest('2d'):
grid = grids.Grid(
(10, 10),
step=0.1,
)
self.assertEqual(grid.step, (0.1, 0.1))
self.assertEqual(grid.domain, ((0, 1.0), (0, 1.0)))
self.assertEqual(grid.ndim, 2)
self.assertEqual(grid.cell_center, (0.5, 0.5))
self.assertEqual(grid.cell_faces, ((1.0, 0.5), (0.5, 1.0)))
with self.subTest('3d'):
grid = grids.Grid((10, 10, 10), step=(0.1, 0.2, 0.5))
self.assertEqual(grid.step, (0.1, 0.2, 0.5))
self.assertEqual(grid.domain, ((0, 1.0), (0, 2.0), (0, 5.0)))
self.assertEqual(grid.ndim, 3)
self.assertEqual(grid.cell_center, (0.5, 0.5, 0.5))
self.assertEqual(grid.cell_faces,
((1.0, 0.5, 0.5), (0.5, 1.0, 0.5), (0.5, 0.5, 1.0)))
with self.subTest('1d domain'):
grid = grids.Grid((10,), domain=[(-2, 2)])
self.assertEqual(grid.step, (2 / 5,))
self.assertEqual(grid.domain, ((-2., 2.),))
self.assertEqual(grid.ndim, 1)
self.assertEqual(grid.cell_center, (0.5,))
self.assertEqual(grid.cell_faces, ((1.0,),))
with self.subTest('2d domain'):
grid = grids.Grid((10, 20), domain=[(-2, 2), (0, 3)])
self.assertEqual(grid.step, (4 / 10, 3 / 20))
self.assertEqual(grid.domain, ((-2., 2.), (0., 3.)))
self.assertEqual(grid.ndim, 2)
self.assertEqual(grid.cell_center, (0.5, 0.5))
self.assertEqual(grid.cell_faces, ((1.0, 0.5), (0.5, 1.0)))
with self.subTest('2d periodic'):
grid = grids.Grid((10, 20), domain=2 * np.pi)
self.assertEqual(grid.step, (2 * np.pi / 10, 2 * np.pi / 20))
self.assertEqual(grid.domain, ((0., 2 * np.pi), (0., 2 * np.pi)))
self.assertEqual(grid.ndim, 2)
with self.assertRaisesRegex(TypeError, 'cannot provide both'):
grids.Grid((2,), step=(1.0,), domain=[(0, 2.0)])
with self.assertRaisesRegex(ValueError, 'length of domain'):
grids.Grid((2, 3), domain=[(0, 1)])
with self.assertRaisesRegex(ValueError, 'pairs of numbers'):
grids.Grid((2,), domain=[(0, 1, 2)])
with self.assertRaisesRegex(ValueError, 'length of step'):
grids.Grid((2, 3), step=(1.0,))
def test_stagger(self):
grid = grids.Grid((10, 10))
array_1 = jnp.zeros((10, 10))
array_2 = jnp.ones((10, 10))
u, v = grid.stagger((array_1, array_2))
self.assertEqual(u.offset, (1.0, 0.5))
self.assertEqual(v.offset, (0.5, 1.0))
def test_center(self):
grid = grids.Grid((10, 10))
with self.subTest('array ndim same as grid'):
array_1 = jnp.zeros((10, 10))
array_2 = jnp.zeros((20, 30))
v = (array_1, array_2) # tuple is a simple pytree
v_centered = grid.center(v)
self.assertLen(v_centered, 2)
self.assertIsInstance(v_centered[0], grids.GridArray)
self.assertIsInstance(v_centered[1], grids.GridArray)
self.assertEqual(v_centered[0].shape, (10, 10))
self.assertEqual(v_centered[1].shape, (20, 30))
self.assertEqual(v_centered[0].offset, (0.5, 0.5))
self.assertEqual(v_centered[1].offset, (0.5, 0.5))
with self.subTest('array ndim different than grid'):
# Assigns offset dimension based on grid.ndim
array_1 = jnp.zeros((10,))
array_2 = jnp.ones((10, 10, 10))
v = (array_1, array_2) # tuple is a simple pytree
v_centered = grid.center(v)
self.assertLen(v_centered, 2)
self.assertIsInstance(v_centered[0], grids.GridArray)
self.assertIsInstance(v_centered[1], grids.GridArray)
self.assertEqual(v_centered[0].shape, (10,))
self.assertEqual(v_centered[1].shape, (10, 10, 10))
self.assertEqual(v_centered[0].offset, (0.5, 0.5))
self.assertEqual(v_centered[1].offset, (0.5, 0.5))
def test_axes_and_mesh(self):
with self.subTest('1d'):
grid = grids.Grid((5,), step=0.1)
axes = grid.axes()
self.assertLen(axes, 1)
self.assertAllClose(axes[0], [0.05, 0.15, 0.25, 0.35, 0.45])
mesh = grid.mesh()
self.assertLen(mesh, 1)
self.assertAllClose(axes[0], mesh[0]) # in 1d, mesh matches array
with self.subTest('1d with offset'):
grid = grids.Grid((5,), step=0.1)
axes = grid.axes(offset=(0,))
self.assertLen(axes, 1)
self.assertAllClose(axes[0], [0.0, 0.1, 0.2, 0.3, 0.4])
mesh = grid.mesh(offset=(0,))
self.assertLen(mesh, 1)
self.assertAllClose(axes[0], mesh[0]) # in 1d, mesh matches array
with self.subTest('2d'):
grid = grids.Grid((4, 6), domain=[(-2, 2), (0, 3)])
axes = grid.axes()
self.assertLen(axes, 2)
self.assertAllClose(axes[0], [-1.5, -0.5, 0.5, 1.5])
self.assertAllClose(axes[1], [0.25, 0.75, 1.25, 1.75, 2.25, 2.75])
mesh = grid.mesh()
self.assertLen(mesh, 2)
self.assertEqual(mesh[0].shape, (4, 6))
self.assertEqual(mesh[1].shape, (4, 6))
self.assertAllClose(mesh[0][:, 0], axes[0])
self.assertAllClose(mesh[1][0, :], axes[1])
with self.subTest('2d with offset'):
grid = grids.Grid((4, 6), domain=[(-2, 2), (0, 3)])
axes = grid.axes(offset=(0, 1))
self.assertLen(axes, 2)
self.assertAllClose(axes[0], [-2.0, -1.0, 0.0, 1.0])
self.assertAllClose(axes[1], [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
mesh = grid.mesh(offset=(0, 1))
self.assertLen(mesh, 2)
self.assertEqual(mesh[0].shape, (4, 6))
self.assertEqual(mesh[1].shape, (4, 6))
self.assertAllClose(mesh[0][:, 0], axes[0])
self.assertAllClose(mesh[1][0, :], axes[1])
@parameterized.parameters(
dict(
shape=(10,),
fn=lambda x: 2 * np.ones_like(x),
offset=None,
expected_array=2 * np.ones((10,)),
expected_offset=(0.5,)),
dict(
shape=(10, 10),
fn=lambda x, y: np.ones_like(x) + np.ones_like(y),
offset=(1, 0.5),
expected_array=2 * np.ones((10, 10)),
expected_offset=(1, 0.5)),
dict(
shape=(10, 10, 10),
fn=lambda x, y, z: np.ones_like(z),
offset=None,
expected_array=np.ones((10, 10, 10)),
expected_offset=(0.5, 0.5, 0.5)),
)
def test_eval_on_mesh_default_offset(self, shape, fn, offset, expected_array,
expected_offset):
grid = grids.Grid(shape, step=0.1)
expected = grids.GridArray(expected_array, expected_offset, grid)
actual = grid.eval_on_mesh(fn, offset)
self.assertArrayEqual(expected, actual)
def test_spectral_axes(self):
length = 42.
shape = (64,)
grid = grids.Grid(shape, domain=((0, length),))
xs, = grid.axes()
fft_xs, = grid.fft_axes()
fft_xs *= 2 * jnp.pi # convert ordinal to angular frequencies
# compare the derivative of the sine function (i.e. cosine) with its
# derivative computed in frequency-space. Note that this derivative involves
# the computed frequencies so it can serve as a test.
angular_freq = 2 * jnp.pi / length
ys = jnp.sin(angular_freq * xs)
expected = angular_freq * jnp.cos(angular_freq * xs)
actual = jnp.fft.ifft(1j * fft_xs * jnp.fft.fft(ys))
self.assertAllClose(expected, actual, atol=1e-4)
def test_real_spectral_axes_1d(self):
length = 42.
shape = (64,)
grid = grids.Grid(shape, domain=((0, length),))
xs, = grid.axes()
fft_xs, = grid.rfft_axes()
fft_xs *= 2 * jnp.pi # convert ordinal to angular frequencies
# compare the derivative of the sine function (i.e. cosine) with its
# derivative computed in frequency-space. Note that this derivative involves
# the computed frequencies so it can serve as a test.
angular_freq = 2 * jnp.pi / length
ys = jnp.sin(angular_freq * xs)
expected = angular_freq * jnp.cos(angular_freq * xs)
actual = jnp.fft.irfft(1j * fft_xs * jnp.fft.rfft(ys))
self.assertAllClose(expected, actual, atol=1e-4)
def test_real_spectral_axes_nd_shape(self):
dim = 3
grid_size = 64
shape = (grid_size,) * dim
domain = ((0, 2 * jnp.pi),) * dim
grid = grids.Grid(shape, domain=domain)
xs1, xs2, xs3 = grid.rfft_axes()
self.assertEqual(len(xs1), grid_size)
self.assertEqual(len(xs2), grid_size)
self.assertEqual(len(xs3), grid_size // 2 + 1)
def test_domain_interior_masks(self):
with self.subTest('1d'):
grid = grids.Grid((5,))
expected = [[1, 1, 1, 1, 0]]
self.assertAllClose(grids.domain_interior_masks(grid), expected)
with self.subTest('2d'):
grid = grids.Grid((3, 3))
expected = ([[1, 1, 1], [1, 1, 1], [0, 0, 0]], [[1, 1, 0], [1, 1, 0],
[1, 1, 0]])
self.assertAllClose(grids.domain_interior_masks(grid), expected)
with self.subTest('3d'):
grid = grids.Grid((3, 4, 5))
actual = grids.domain_interior_masks(grid)
self.assertLen(actual, 3)
# masks are zero on the outer edge, 1 on the interior
self.assertAllClose(actual[0][:-1, :, :], 1)
self.assertAllClose(actual[0][-1, :, :], 0)
self.assertAllClose(actual[1][:, :-1, :], 1)
self.assertAllClose(actual[1][:, -1, :], 0)
self.assertAllClose(actual[2][:, :, :-1], 1)
self.assertAllClose(actual[2][:, :, -1], 0)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare initial conditions for simulations."""
import functools
from typing import Callable, Optional, Sequence, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import filter_utils
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import pressure
import numpy as np
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
Array = Union[np.ndarray, jax.Array]
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions
def wrap_variables(
var: Sequence[Array],
grid: grids.Grid,
bcs: Sequence[BoundaryConditions],
offsets: Optional[Sequence[Tuple[float, ...]]] = None,
batch_dim: bool = False,
) -> GridVariableVector:
"""Associates offsets, grid, and boundary conditions with a sequence of arrays."""
if offsets is None:
offsets = grid.cell_faces
def impose_bc(arrays):
return tuple(
bc.impose_bc(grids.GridArray(u, offset, grid))
for u, offset, bc in zip(arrays, offsets, bcs))
if batch_dim:
return jax.vmap(impose_bc)(var)
else:
return impose_bc(var)
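# A minimal usage sketch; `_example_wrap_variables` is a hypothetical helper,
# not part of the library API. It wraps two raw velocity arrays with periodic
# boundary conditions at the default cell-face offsets.
def _example_wrap_variables():
  grid = grids.Grid((8, 8))
  bcs = [boundaries.periodic_boundary_conditions(grid.ndim)] * grid.ndim
  raw = [jnp.zeros(grid.shape), jnp.ones(grid.shape)]
  return wrap_variables(raw, grid, bcs)  # tuple of two GridVariables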
def _log_normal_pdf(x, mode, variance=.25):
"""Unscaled PDF for a log normal given `mode` and log variance 1."""
mean = jnp.log(mode) + variance
logx = jnp.log(x)
return jnp.exp(-(mean - logx)**2 / 2 / variance - logx)
def _max_speed(v):
return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()
def filtered_velocity_field(
rng_key: grids.Array,
grid: grids.Grid,
maximum_velocity: float = 1,
peak_wavenumber: float = 3,
iterations: int = 3,
) -> GridVariableVector:
"""Create divergence-free velocity fields with appropriate spectral filtering.
Args:
rng_key: key for seeding the random initial velocity field.
grid: the grid on which the velocity field is defined.
maximum_velocity: the maximum speed in the velocity field.
peak_wavenumber: the velocity field will be filtered so that the largest
magnitudes are associated with this wavenumber.
iterations: the number of repeated pressure projection and normalization
iterations to apply.
Returns:
A divergence free velocity field with the given maximum velocity. Associates
periodic boundary conditions with the velocity field components.
"""
# Log normal distribution peaked at `peak_wavenumber`. Note that we have to
# divide by `k ** (ndim - 1)` to account for the volume of the
# `ndim - 1`-sphere of values with wavenumber `k`.
def spectral_density(k):
return _log_normal_pdf(k, peak_wavenumber) / k ** (grid.ndim - 1)
# TODO(b/156601712): Switch back to the vmapped implementation of filtering:
# noise = jax.random.normal(rng_key, (grid.ndim,) + grid.shape)
# filtered = wrap_velocities(jax.vmap(spectral.filter, (None, 0, None))(
# spectral_density, noise, grid), grid)
keys = jax.random.split(rng_key, grid.ndim)
velocity_components = []
boundary_conditions = []
for k in keys:
noise = jax.random.normal(k, grid.shape)
velocity_components.append(
filter_utils.filter(spectral_density, noise, grid))
boundary_conditions.append(
boundaries.periodic_boundary_conditions(grid.ndim))
velocity = wrap_variables(velocity_components, grid, boundary_conditions)
def project_and_normalize(v: GridVariableVector):
v = pressure.projection(v)
vmax = _max_speed(v)
v = tuple(
grids.GridVariable(maximum_velocity * u.array / vmax, u.bc) for u in v)
return v
# Due to numerical precision issues, we repeatedly normalize and project the
# velocity field. This ensures that it is divergence-free and achieves the
# specified maximum velocity.
return funcutils.repeated(project_and_normalize, iterations)(velocity)
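# A minimal usage sketch; `_example_filtered_velocity_field` is a hypothetical
# helper, not part of the library API. It draws a random, divergence-free,
# spectrally filtered velocity field on a small periodic grid.
def _example_filtered_velocity_field():
  grid = grids.Grid((64, 64), domain=((0, 2 * np.pi), (0, 2 * np.pi)))
  return filtered_velocity_field(
      jax.random.PRNGKey(0), grid, maximum_velocity=1., peak_wavenumber=3)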
def initial_velocity_field(
velocity_fns: Tuple[Callable[..., Array], ...],
grid: grids.Grid,
velocity_bc: Optional[Sequence[BoundaryConditions]] = None,
pressure_solve: Callable = pressure.solve_fast_diag,
iterations: Optional[int] = None,
) -> GridVariableVector:
"""Given velocity functions on arrays, returns the velocity field on the grid.
Typical usage example:
grid = cfd.grids.Grid((128, 128))
x_velocity_fn = lambda x, y: jnp.sin(x) * jnp.cos(y)
y_velocity_fn = lambda x, y: jnp.zeros_like(x)
v0 = initial_velocity_field(
    (x_velocity_fn, y_velocity_fn), grid, iterations=5)
Args:
velocity_fns: functions for computing each velocity component. These should
take the args (x, y, ...) and return an array of the same shape.
grid: the grid on which the velocity field is defined.
velocity_bc: the boundary conditions to associate with each velocity
component. If unspecified, uses periodic boundary conditions.
pressure_solve: method used to solve pressure projection.
iterations: if specified, the number of times to apply pressure projection
onto an incompressible velocity field.
Returns:
Velocity components defined with expected offsets on the grid.
"""
if velocity_bc is None:
velocity_bc = (
boundaries.periodic_boundary_conditions(grid.ndim),) * grid.ndim
v = tuple(
grids.GridVariable(grid.eval_on_mesh(v_fn, offset), bc) for v_fn, offset,
bc in zip(velocity_fns, grid.cell_faces, velocity_bc))
if iterations is not None:
projection = functools.partial(pressure.projection, solve=pressure_solve)
v = funcutils.repeated(projection, iterations)(v)
return v
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.initial_conditions."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import initial_conditions as ic
from jax_cfd.base import pressure
from jax_cfd.base import test_util
import numpy as np
def get_grid(grid_size, ndim, domain_size_multiple=1):
domain = ((0, 2 * np.pi * domain_size_multiple),) * ndim
shape = (grid_size,) * ndim
return grids.Grid(shape=shape, domain=domain)
class InitialConditionsTest(test_util.TestCase):
@parameterized.parameters(
dict(seed=3232,
grid=get_grid(128, ndim=3),
maximum_velocity=1.,
peak_wavenumber=2),
dict(seed=2323,
grid=get_grid(1024, ndim=2),
maximum_velocity=10.,
peak_wavenumber=17),
)
def test_filtered_velocity_field(
self, seed, grid, maximum_velocity, peak_wavenumber):
v = ic.filtered_velocity_field(
jax.random.PRNGKey(seed), grid, maximum_velocity, peak_wavenumber)
actual_maximum_velocity = jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()
max_divergence = fd.divergence(v).data.max()
# Assert that initial velocity is divergence free
self.assertAllClose(0., max_divergence, atol=5e-4)
# Assert that the specified maximum velocity is obtained.
self.assertAllClose(maximum_velocity, actual_maximum_velocity, atol=1e-5)
def test_initial_velocity_field_no_projection(self):
# Test on an already incompressible velocity field
grid = grids.Grid((10, 10), step=1.0)
x_velocity_fn = lambda x, y: jnp.ones_like(x)
y_velocity_fn = lambda x, y: jnp.zeros_like(x)
v0 = ic.initial_velocity_field((x_velocity_fn, y_velocity_fn), grid)
expected_v0 = (
grids.GridVariable(
grids.GridArray(jnp.ones((10, 10)), (1, 0.5), grid),
boundaries.periodic_boundary_conditions(grid.ndim)),
grids.GridVariable(
grids.GridArray(jnp.zeros((10, 10)), (0.5, 1), grid),
boundaries.periodic_boundary_conditions(grid.ndim)),
)
for d in range(len(v0)):
self.assertArrayEqual(expected_v0[d], v0[d])
with self.subTest('correction does not change answer'):
v0_corrected = ic.initial_velocity_field(
(x_velocity_fn, y_velocity_fn), grid, iterations=5)
for d in range(len(v0)):
self.assertIsInstance(v0_corrected[d], grids.GridVariable)
self.assertArrayEqual(expected_v0[d], v0_corrected[d])
@parameterized.parameters(
dict(
velocity_bc=(boundaries.dirichlet_boundary_conditions(2),
boundaries.dirichlet_boundary_conditions(2)),
pressure_solve=pressure.solve_cg,
),
dict(
velocity_bc=(boundaries.channel_flow_boundary_conditions(2),
boundaries.channel_flow_boundary_conditions(2)),
pressure_solve=pressure.solve_cg,
),
dict(
velocity_bc=(boundaries.channel_flow_boundary_conditions(2),
boundaries.channel_flow_boundary_conditions(2)),
pressure_solve=pressure.solve_fast_diag_channel_flow,
),
dict(velocity_bc=None, # default is all periodic BC.
pressure_solve=pressure.solve_fast_diag,
),
)
def test_initial_velocity_field_with_projection(self, velocity_bc,
pressure_solve):
grid = grids.Grid((20, 20), step=0.1)
# Use a mask to make the random noise zero on the boundaries, consistent
# with Dirichlet BC (and still valid for periodic BC).
masks = grids.domain_interior_masks(grid)
def x_velocity_fn(x, y):
return jnp.zeros_like(x + y) + 0.2 * np.random.normal(
size=grid.shape) * masks[0]
def y_velocity_fn(x, y):
return jnp.zeros_like(x + y) + 0.2 * np.random.normal(
size=grid.shape) * masks[1]
with self.subTest('corrected'):
v0_corrected = ic.initial_velocity_field((x_velocity_fn, y_velocity_fn),
grid,
velocity_bc,
pressure_solve,
iterations=5)
self.assertAllClose(fd.divergence(v0_corrected).data, 0, atol=1e-5)
with self.subTest('not corrected'):
v0_uncorrected = ic.initial_velocity_field((x_velocity_fn, y_velocity_fn),
grid,
velocity_bc,
pressure_solve,
iterations=None)
self.assertGreater(abs(fd.divergence(v0_uncorrected).data).max(), 0.1)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for interpolating `GridVariables`s."""
from typing import Callable, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np
Array = Union[np.ndarray, jax.Array]
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = Callable[
[GridVariable, Tuple[float, ...], GridVariableVector, float],
GridVariable]
FluxLimiter = Callable[[grids.Array], grids.Array]
def _linear_along_axis(c: GridVariable,
offset: float,
axis: int) -> GridVariable:
"""Linear interpolation of `c` to `offset` along a single specified `axis`."""
offset_delta = offset - c.offset[axis]
# If offsets are the same, `c` is unchanged.
if offset_delta == 0:
return c
new_offset = tuple(offset if j == axis else o
for j, o in enumerate(c.offset))
# If offsets differ by an integer, we can just shift `c`.
if int(offset_delta) == offset_delta:
return grids.GridVariable(
array=grids.GridArray(data=c.shift(int(offset_delta), axis).data,
offset=new_offset,
grid=c.grid),
bc=c.bc)
floor = int(np.floor(offset_delta))
ceil = int(np.ceil(offset_delta))
floor_weight = ceil - offset_delta
ceil_weight = 1. - floor_weight
data = (floor_weight * c.shift(floor, axis).data +
ceil_weight * c.shift(ceil, axis).data)
return grids.GridVariable(
array=grids.GridArray(data, new_offset, c.grid), bc=c.bc)
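# Illustrative sketch; `_example_linear_along_axis` is a hypothetical helper,
# not part of the library API. Interpolating from a cell center (offset 0.5)
# to the forward face (offset 1.0) uses floor/ceil shifts of 0 and 1 with
# equal weights, i.e. the average of each value and its forward neighbor.
def _example_linear_along_axis():
  grid = grids.Grid((4,))
  bc = boundaries.periodic_boundary_conditions(grid.ndim)
  c = grids.GridVariable(
      grids.GridArray(jnp.array([0., 1., 2., 3.]), (0.5,), grid), bc)
  face = _linear_along_axis(c, offset=1.0, axis=0)
  np.testing.assert_allclose(face.data, [0.5, 1.5, 2.5, 1.5])  # periodic wrap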
def linear(
c: GridVariable,
offset: Tuple[float, ...],
v: Optional[GridVariableVector] = None,
dt: Optional[float] = None
) -> grids.GridVariable:
"""Multi-linear interpolation of `c` to `offset`.
Args:
c: quantity to be interpolated.
offset: offset to which we will interpolate `c`. Must have the same length
as `c.offset`.
v: velocity field. Not used.
dt: size of the time step. Not used.
Returns:
A `GridVariable` containing the values of `c` after linear interpolation
to `offset`. The returned value will have offset equal to `offset`.
"""
del v, dt # unused
if len(offset) != len(c.offset):
raise ValueError('`c.offset` and `offset` must have the same length; '
f'got {c.offset} and {offset}.')
interpolated = c
for a, o in enumerate(offset):
interpolated = _linear_along_axis(interpolated, offset=o, axis=a)
return interpolated
def upwind(
c: GridVariable,
offset: Tuple[float, ...],
v: GridVariableVector,
dt: Optional[float] = None
) -> GridVariable:
"""Upwind interpolation of `c` to `offset` based on velocity field `v`.
Interpolates values of `c` to `offset` in two steps:
1) Identifies the axis along which `c` is interpolated. (must be single axis)
2) For positive (negative) velocity along interpolation axis uses value from
the previous (next) cell along that axis correspondingly.
Args:
c: quantity to be interpolated.
offset: offset to which `c` will be interpolated. Must have the same
length as `c.offset` and differ in at most one entry.
v: velocity field with offsets at faces of `c`. One of the components
must have the same offset as `offset`.
dt: size of the time step. Not used.
Returns:
A `GridVariable` that contains the values of `c` after interpolation to
`offset`.
Raises:
InconsistentOffsetError: if `offset` and `c.offset` differ in more than one
entry.
"""
del dt # unused
if c.offset == offset: return c
interpolation_axes = tuple(
axis for axis, (current, target) in enumerate(zip(c.offset, offset))
if current != target
)
if len(interpolation_axes) != 1:
raise grids.InconsistentOffsetError(
f'for upwind interpolation `c.offset` and `offset` must differ at most '
f'in one entry, but got: {c.offset} and {offset}.')
axis, = interpolation_axes
u = v[axis]
offset_delta = u.offset[axis] - c.offset[axis]
# If offsets differ by an integer, we can just shift `c`.
if int(offset_delta) == offset_delta:
return grids.GridVariable(
array=grids.GridArray(data=c.shift(int(offset_delta), axis).data,
offset=offset,
grid=grids.consistent_grid(c, u)),
bc=c.bc)
floor = int(np.floor(offset_delta))
ceil = int(np.ceil(offset_delta))
array = grids.applied(jnp.where)(
u.array > 0, c.shift(floor, axis).data, c.shift(ceil, axis).data
)
grid = grids.consistent_grid(c, u)
return grids.GridVariable(
array=grids.GridArray(array.data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
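# Illustrative sketch; `_example_upwind` is a hypothetical helper, not part of
# the library API. With a uniformly positive velocity, upwind interpolation to
# the forward face reuses the value of the cell behind that face, i.e. the
# unshifted cell-centered data.
def _example_upwind():
  grid = grids.Grid((4, 4))
  bc = boundaries.periodic_boundary_conditions(grid.ndim)
  c = grids.GridVariable(
      grids.GridArray(jnp.arange(16.).reshape((4, 4)), (0.5, 0.5), grid), bc)
  u = grids.GridVariable(
      grids.GridArray(jnp.ones((4, 4)), (0.5, 1.0), grid), bc)
  v = (None, u)  # only the component along the interpolation axis is used
  face = upwind(c, (0.5, 1.0), v)
  np.testing.assert_allclose(face.data, c.data)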
def lax_wendroff(
c: GridVariable,
offset: Tuple[float, ...],
v: Optional[GridVariableVector] = None,
dt: Optional[float] = None
) -> GridVariable:
"""Lax_Wendroff interpolation of `c` to `offset` based on velocity field `v`.
Interpolates values of `c` to `offset` in two steps:
1) Identifies the axis along which `c` is interpolated. (must be single axis)
2) For positive (negative) velocity along interpolation axis uses value from
the previous (next) cell along that axis plus a correction originating
from expansion of the solution at the half step-size.
This method is second order accurate with fixed coefficients and hence can't
be monotonic due to Godunov's theorem.
https://en.wikipedia.org/wiki/Godunov%27s_theorem
Lax-Wendroff method can be used to form monotonic schemes when augmented with
a flux limiter. See https://en.wikipedia.org/wiki/Flux_limiter
Args:
c: quantity to be interpolated.
offset: offset to which we will interpolate `c`. Must have the same
length as `c.offset` and differ in at most one entry.
v: velocity field with offsets at faces of `c`. One of the components must
have the same offset as `offset`.
dt: size of the time step, used to compute the local Courant number.
Returns:
A `GridVariable` that contains the values of `c` after interpolation to
`offset`.
Raises:
InconsistentOffsetError: if `offset` and `c.offset` differ in more than one
entry.
"""
# TODO(dkochkov) add a function to compute interpolation axis.
if c.offset == offset: return c
interpolation_axes = tuple(
axis for axis, (current, target) in enumerate(zip(c.offset, offset))
if current != target
)
if len(interpolation_axes) != 1:
raise grids.InconsistentOffsetError(
f'for Lax-Wendroff interpolation `c.offset` and `offset` must differ at'
f' most in one entry, but got: {c.offset} and {offset}.')
axis, = interpolation_axes
u = v[axis]
offset_delta = u.offset[axis] - c.offset[axis]
floor = int(np.floor(offset_delta)) # used for positive velocity
ceil = int(np.ceil(offset_delta)) # used for negative velocity
grid = grids.consistent_grid(c, u)
courant_numbers = (dt / grid.step[axis]) * u.data
positive_u_case = (
c.shift(floor, axis).data + 0.5 * (1 - courant_numbers) *
(c.shift(ceil, axis).data - c.shift(floor, axis).data))
negative_u_case = (
c.shift(ceil, axis).data - 0.5 * (1 + courant_numbers) *
(c.shift(ceil, axis).data - c.shift(floor, axis).data))
array = grids.where(u.array > 0, positive_u_case, negative_u_case)
grid = grids.consistent_grid(c, u)
return grids.GridVariable(
array=grids.GridArray(array.data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
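# Illustrative sketch; `_example_lax_wendroff` is a hypothetical helper, not
# part of the library API. With dt = 0 the Courant number vanishes and
# Lax-Wendroff reduces to the centered average of the two cells adjacent to
# the face (i.e. linear interpolation).
def _example_lax_wendroff():
  grid = grids.Grid((4,))
  bc = boundaries.periodic_boundary_conditions(grid.ndim)
  c = grids.GridVariable(
      grids.GridArray(jnp.array([0., 1., 2., 3.]), (0.5,), grid), bc)
  u = grids.GridVariable(grids.GridArray(jnp.ones((4,)), (1.0,), grid), bc)
  face = lax_wendroff(c, (1.0,), (u,), dt=0.)
  np.testing.assert_allclose(face.data, [0.5, 1.5, 2.5, 1.5])  # periodic wrap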
def safe_div(x, y, default_numerator=1):
"""Safe division of `Array`'s."""
return x / jnp.where(y != 0, y, default_numerator)
def van_leer_limiter(r):
"""Van-leer flux limiter."""
return jnp.where(r > 0, safe_div(2 * r, 1 + r), 0.0)
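# Worked values; `_example_van_leer_limiter` is a hypothetical helper, not
# part of the library API. The van Leer limiter is phi(r) = 2r / (1 + r) for
# r > 0 and 0 otherwise, so phi(0) = 0, phi(1) = 1 and phi(r) -> 2 as r grows.
def _example_van_leer_limiter():
  r = jnp.array([-1., 0., 1., 3.])
  np.testing.assert_allclose(van_leer_limiter(r), [0., 0., 1., 1.5])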
def apply_tvd_limiter(
interpolation_fn: InterpolationFn,
limiter: FluxLimiter = van_leer_limiter
) -> InterpolationFn:
"""Combines low and high accuracy interpolators to get TVD method.
Generates a high accuracy interpolator by combining the stable low accuracy
`upwind` interpolation with the high accuracy (but not guaranteed to be
stable) `interpolation_fn`, to obtain a stable higher order method. This
implementation follows the procedure outlined in:
http://www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf
Args:
interpolation_fn: higher order interpolation method. Must follow the same
interface as other interpolation methods (take `c`, `offset`, `v` and `dt`
arguments and return the value of `c` at offset `offset`).
limiter: flux limiter function that evaluates the portion of the correction
(high_accuracy - low_accuracy) to add to the low_accuracy solution, based
on the ratio of consecutive gradients. Takes an array as input and returns
an array of weights. For more details see:
https://en.wikipedia.org/wiki/Flux_limiter
Returns:
An interpolation method that combines the high and low order methods to
produce a monotonic interpolation.
"""
def tvd_interpolation(
c: GridVariable,
offset: Tuple[float, ...],
v: GridVariableVector,
dt: float,
) -> GridVariable:
"""Interpolated `c` to offset `offset`."""
for axis, axis_offset in enumerate(offset):
interpolation_offset = tuple([
c_offset if i != axis else axis_offset
for i, c_offset in enumerate(c.offset)
])
if interpolation_offset != c.offset:
if interpolation_offset[axis] - c.offset[axis] != 0.5:
raise NotImplementedError('tvd_interpolation only supports forward '
'interpolation to control volume faces.')
c_low = upwind(c, offset, v, dt)
c_high = interpolation_fn(c, offset, v, dt)
# Because we are interpolating to the right face, we use two points ahead
# (`c_right`, `c_next_right`) and two points behind (`c`, `c_left`).
c_left = c.shift(-1, axis)
c_right = c.shift(1, axis)
c_next_right = c.shift(2, axis)
# Velocities of different sign are evaluated with limiters at different
# points. See equations (4.34) -- (4.39) from the reference above.
positive_u_r = safe_div(c.data - c_left.data, c_right.data - c.data)
negative_u_r = safe_div(c_next_right.data - c_right.data,
c_right.data - c.data)
positive_u_phi = grids.GridArray(
limiter(positive_u_r), c_low.offset, c.grid)
negative_u_phi = grids.GridArray(
limiter(negative_u_r), c_low.offset, c.grid)
u = v[axis]
phi = grids.applied(jnp.where)(
u.array > 0, positive_u_phi, negative_u_phi)
c_interpolated = c_low.array - (c_low.array - c_high.array) * phi
c = grids.GridVariable(
grids.GridArray(c_interpolated.data, interpolation_offset, c.grid),
c.bc)
return c
return tvd_interpolation
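# A minimal usage sketch; `_example_apply_tvd_limiter` is a hypothetical
# helper, not part of the library API. It builds a TVD-limited interpolation
# by combining Lax-Wendroff with the van Leer limiter, then uses it like any
# other interpolation function.
def _example_apply_tvd_limiter():
  tvd_lax_wendroff = apply_tvd_limiter(lax_wendroff, limiter=van_leer_limiter)
  grid = grids.Grid((8,))
  bc = boundaries.periodic_boundary_conditions(grid.ndim)
  c = grids.GridVariable(grids.GridArray(jnp.arange(8.), (0.5,), grid), bc)
  u = grids.GridVariable(grids.GridArray(jnp.ones((8,)), (1.0,), grid), bc)
  return tvd_lax_wendroff(c, (1.0,), (u,), dt=0.1)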
# TODO(pnorgaard) Consider changing c to GridVariable
# Not required since no .shift() method is used
def point_interpolation(
point: Array,
c: GridArray,
order: int = 1,
mode: str = 'nearest',
cval: float = 0.0,
) -> jax.Array:
"""Interpolate `c` at `point`.
Args:
point: length N 1-D Array. The point to interpolate to.
c: N-dimensional GridArray. The values that will be interpolated.
order: Integer in the range 0-1. The order of the spline interpolation.
mode: one of {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}.
The `mode` parameter determines how the input array is extended
beyond its boundaries. Default is 'nearest'. Behavior for each valid
value is as follows:
'reflect' (`d c b a | a b c d | d c b a`)
The input is extended by reflecting about the edge of the last
pixel.
'constant' (`k k k k | a b c d | k k k k`)
The input is extended by filling all values beyond the edge with
the same constant value, defined by the `cval` parameter.
'nearest' (`a a a a | a b c d | d d d d`)
The input is extended by replicating the last pixel.
'mirror' (`d c b | a b c d | c b a`)
The input is extended by reflecting about the center of the last
pixel.
'wrap' (`a b c d | a b c d | a b c d`)
The input is extended by wrapping around to the opposite edge.
cval: Value to fill past edges of input if `mode` is 'constant'. Default 0.0
Returns:
the interpolated value at `point`.
"""
point = jnp.asarray(point)
domain_lower, domain_upper = zip(*c.grid.domain)
domain_lower = jnp.array(domain_lower)
domain_upper = jnp.array(domain_upper)
shape = jnp.array(c.grid.shape)
offset = jnp.array(c.offset)
# For each dimension `i` in point,
# The map from `point[i]` to index is linear.
# index(domain_lower[i]) = -offset[i]
# index(domain_upper[i]) = shape[i] - offset[i]
# This is easily vectorized as
index = (-offset + (point - domain_lower) * shape /
(domain_upper - domain_lower))
return jax.scipy.ndimage.map_coordinates(
c.data, coordinates=index, order=order, mode=mode, cval=cval)
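# A minimal usage sketch; `_example_point_interpolation` is a hypothetical
# helper, not part of the library API. It samples a cell-centered field at an
# arbitrary physical location using first-order (linear) interpolation.
def _example_point_interpolation():
  grid = grids.Grid((10,), domain=((0., 1.),))
  c = grid.eval_on_mesh(lambda x: x ** 2)  # cell-centered samples of x**2
  return point_interpolation(jnp.array([0.5]), c, order=1)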
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.interpolation."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.base import test_util
import numpy as np
import scipy.interpolate as spi
def periodic_grid_variable(data, offset, grid):
return grids.GridVariable(
array=grids.GridArray(data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
class LinearInterpolationTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_offset_too_short',
shape=(10, 10, 10),
step=(1., 1., 1.),
offset=(2., 3.)),
dict(testcase_name='_offset_too_long',
shape=(10, 10),
step=(1., 1.),
offset=(2., 3., 4.))
)
def testRaisesForInvalidOffset(self, shape, step, offset):
"""Test that incompatible offsets raise an exception."""
grid = grids.Grid(shape, step)
u = periodic_grid_variable(jnp.ones(shape), jnp.zeros(shape), grid)
with self.assertRaises(ValueError):
interpolation.linear(u, offset)
@parameterized.named_parameters(
dict(testcase_name='_1D',
shape=(100,),
step=(.1,),
f=(lambda x: np.random.RandomState(123).randn(*x[0].shape)),
initial_offset=(-.5,),
final_offset=(.5,)),
dict(testcase_name='_2D',
shape=(100, 100),
step=(1., 1.),
f=(lambda xy: np.random.RandomState(231).randn(*xy[0].shape)),
initial_offset=(1., 0.),
final_offset=(0., 0.)),
dict(testcase_name='_3D',
shape=(100, 100, 100),
step=(.3, .4, .5),
f=(lambda xyz: np.random.RandomState(312).randn(*xyz[0].shape)),
initial_offset=(0., 1., 0.),
final_offset=(.5, .5, .5)),
)
def testEquivalenceWithScipy(
self, shape, step, f, initial_offset, final_offset):
"""Tests that interpolation is close to results of `scipy.interpolate`."""
grid = grids.Grid(shape, step)
initial_mesh = grid.mesh(offset=initial_offset)
initial_axes = grid.axes(offset=initial_offset)
initial_u = periodic_grid_variable(f(initial_mesh), initial_offset, grid)
final_mesh = grid.mesh(offset=final_offset)
final_u = interpolation.linear(initial_u, final_offset)
expected_data = spi.interpn(initial_axes,
initial_u.data,
jnp.stack(final_mesh, -1),
method='linear',
bounds_error=False)
# Scipy does not support periodic boundaries so we compare only valid
# values.
valid = np.where(~jnp.isnan(expected_data))
self.assertAllClose(
expected_data[valid], final_u.data[valid], atol=1e-4)
self.assertAllClose(final_offset, final_u.offset)
class UpwindInterpolationTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_2D',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_offset=(1., .5),
u_offset=(.5, 1.)),
dict(testcase_name='_3D',
grid_shape=(10, 10, 10),
grid_step=(1., 1., 1.),
c_offset=(.5, 1., .5),
u_offset=(.5, .5, 1.)),
)
def testRaisesForInvalidOffset(
self, grid_shape, grid_step, c_offset, u_offset):
"""Test that incompatible offsets raise an exception."""
grid = grids.Grid(grid_shape, grid_step)
c = periodic_grid_variable(jnp.ones(grid_shape), c_offset, grid)
with self.assertRaises(grids.InconsistentOffsetError):
interpolation.upwind(c, u_offset, None)
@parameterized.named_parameters(
dict(testcase_name='_2D_positive_velocity',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
expected_data=lambda: jnp.arange(10. * 10.).reshape((10, 10))),
dict(testcase_name='_2D_negative_velocity',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: -1. * jnp.ones((10, 10)),
u_offset=(1., .5),
u_axis=0,
expected_data=lambda: jnp.roll( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)), shift=-1, axis=0)),
dict(testcase_name='_2D_negative_velocity_large_offset',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: -1. * jnp.ones((10, 10)),
u_offset=(2., .5),
u_axis=0,
expected_data=lambda: jnp.roll( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)), shift=-2, axis=0)),
dict(testcase_name='_2D_integer_offset',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(0.5, 1),
u_data=lambda: -1. * jnp.ones((10, 10)),
u_offset=(0.5, 0),
u_axis=1,
expected_data=lambda: jnp.roll( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)), shift=1, axis=1)),
)
def testCorrectness(self, grid_shape, grid_step, c_data, c_offset, u_data,
u_offset, u_axis, expected_data):
grid = grids.Grid(grid_shape, grid_step)
initial_c = periodic_grid_variable(c_data(), c_offset, grid)
u = periodic_grid_variable(u_data(), u_offset, grid)
v = tuple(
u if axis == u_axis else None for axis, _ in enumerate(u_offset)
)
final_c = interpolation.upwind(initial_c, u_offset, v)
self.assertAllClose(expected_data(), final_c.data)
self.assertAllClose(u_offset, final_c.offset)
class LaxWendroffInterpolationTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_2D',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_offset=(1., .5),
u_offset=(.5, 1.)),
dict(testcase_name='_3D',
grid_shape=(10, 10, 10),
grid_step=(1., 1., 1.),
c_offset=(.5, 1., .5),
u_offset=(.5, .5, 1.)),
)
def testRaisesForInvalidOffset(
self, grid_shape, grid_step, c_offset, u_offset):
"""Test that incompatible offsets raise an exception."""
grid = grids.Grid(grid_shape, grid_step)
c = periodic_grid_variable(jnp.ones(grid_shape), c_offset, grid)
with self.assertRaises(grids.InconsistentOffsetError):
interpolation.lax_wendroff(c, u_offset, v=None, dt=0.)
@parameterized.named_parameters(
dict(
testcase_name='_2D_positive_velocity_courant=1',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
dt=1., # courant = u * dt / grid_step = 1
expected_data=lambda: jnp.arange(10. * 10.).reshape((10, 10))),
dict(
testcase_name='_2D_negative_velocity_courant=-1',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: -1. * jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
dt=1., # courant = u * dt / grid_step = -1
expected_data=lambda: jnp.roll( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)), shift=-1, axis=1)),
dict(
testcase_name='_2D_positive_velocity_courant=0',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
dt=0., # courant = u * dt / grid_step = 0
          # for courant = 0, the result is the centered average of adjacent cells
expected_data=lambda: 0.5 * ( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)) + jnp.roll(
jnp.arange(10. * 10.).reshape((10, 10)), shift=-1, axis=1))),
dict(
testcase_name='_2D_negative_velocity_courant=0',
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: -1. * jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
dt=0., # courant = u * dt / grid_step = 0
          # for courant = 0, the result is the centered average of adjacent cells
expected_data=lambda: 0.5 * ( # pylint: disable=g-long-lambda
jnp.arange(10. * 10.).reshape((10, 10)) + jnp.roll(
jnp.arange(10. * 10.).reshape((10, 10)), shift=-1, axis=1))),
)
def testCorrectness(self, grid_shape, grid_step, c_data, c_offset, u_data,
u_offset, u_axis, dt, expected_data):
grid = grids.Grid(grid_shape, grid_step)
initial_c = periodic_grid_variable(c_data(), c_offset, grid)
u = periodic_grid_variable(u_data(), u_offset, grid)
v = tuple(
u if axis == u_axis else None for axis, _ in enumerate(u_offset)
)
final_c = interpolation.lax_wendroff(initial_c, u_offset, v, dt)
self.assertAllClose(expected_data(), final_c.data)
self.assertAllClose(u_offset, final_c.offset)
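# A minimal 1D reference for the expectations above (an illustrative sketch,
# not part of the library API). Assuming positive velocity u and periodic data,
# the Lax-Wendroff face value is
#   c_face[i] = c[i] + 0.5 * (1 - courant) * (c[i + 1] - c[i]),
# so courant = 1 recovers pure upwinding and courant = 0 recovers the centered
# average of adjacent cells used in the courant = 0 test cases.
def _reference_lax_wendroff_1d(c, u, dt, step):
  courant = u * dt / step
  return c + 0.5 * (1 - courant) * (np.roll(c, -1) - c)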
class ApplyTvdLimiterTest(test_util.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='_case_where_lax_wendroff_is_same_as_upwind',
interpolation_fn=interpolation.lax_wendroff,
limiter=interpolation.van_leer_limiter,
grid_shape=(10, 10),
grid_step=(1., 1.),
c_data=lambda: jnp.arange(10. * 10.).reshape((10, 10)),
c_offset=(.5, .5),
u_data=lambda: jnp.ones((10, 10)),
u_offset=(.5, 1.),
u_axis=1,
dt=1.,
expected_fn=interpolation.upwind),
)
def testCorrectness(self, interpolation_fn, limiter, grid_shape, grid_step,
c_data, c_offset, u_data, u_offset, u_axis, dt,
expected_fn):
c_interpolation_fn = interpolation.apply_tvd_limiter(
interpolation_fn, limiter)
grid = grids.Grid(grid_shape, grid_step)
initial_c = periodic_grid_variable(c_data(), c_offset, grid)
u = periodic_grid_variable(u_data(), u_offset, grid)
v = tuple(
u if axis == u_axis else None for axis, _ in enumerate(u_offset)
)
final_c = c_interpolation_fn(initial_c, u_offset, v, dt)
expected = expected_fn(initial_c, u_offset, v, dt)
self.assertAllClose(expected, final_c)
self.assertAllClose(u_offset, final_c.offset)
class PointInterpolationTest(test_util.TestCase):
def test_eval_of_2d_function(self):
grid_shape = (50, 111)
offset = (0.5, 0.8)
grid = grids.Grid(grid_shape, domain=((0., jnp.pi),) * 2)
xy_grid_pts = jnp.stack(grid.mesh(offset), axis=-1)
func = lambda xy: jnp.sin(xy[..., 0]) * xy[..., 1]
c = grids.GridArray(data=func(xy_grid_pts), offset=offset, grid=grid)
vec_interp = jax.vmap(
interpolation.point_interpolation, in_axes=(0, None), out_axes=0)
    # At the grid points, accuracy should be nearly perfect, up to
    # floating-point epsilon.
xy_grid_pts = jnp.reshape(xy_grid_pts, (-1, 2)) # Reshape for vmap.
self.assertAllClose(
vec_interp(xy_grid_pts, c), func(xy_grid_pts), atol=1e-6)
# At random points, tol is guided by standard first order method heuristics.
atol = 1 / min(*grid_shape)
xy_random = np.random.RandomState(0).rand(100, 2) * np.pi
self.assertAllClose(
vec_interp(xy_random, c), func(xy_random), atol=atol)
def test_order_equals_0_is_piecewise_constant(self):
grid_shape = (3,)
offset = (0.5,)
grid = grids.Grid(grid_shape, domain=((0., 1.),))
x_grid_pts, = grid.mesh(offset=offset)
func = lambda x: 2 * x**2
c = grids.GridArray(data=func(x_grid_pts), offset=offset, grid=grid)
def _nearby_points(value):
eps = grid.step[0] / 3
return [value - eps, value, value + eps]
def _interp(x):
return interpolation.point_interpolation(x, c, order=0)
for x in x_grid_pts:
for near_x in _nearby_points(x):
np.testing.assert_allclose(func(x), _interp(near_x))
def test_mode_and_cval_args_are_used(self):
# Just test that the mode arg is used, in the simplest way.
    # We don't have to check the correctness of these args, since
# jax.scipy.ndimage.map_coordinates is tested separately.
grid_shape = (3,)
offset = (0.5,)
grid = grids.Grid(grid_shape, domain=((0., 1.),))
x_grid_pts, = grid.mesh(offset=offset)
c = grids.GridArray(
# Just use random points. Won't affect anything.
data=x_grid_pts * 0.1, offset=offset, grid=grid)
# Outside the domain, the passing of mode='constant' and cval=1234 results
# in this value being used.
outside_domain_point = 10.
self.assertAllClose(
1234,
interpolation.point_interpolation(
outside_domain_point, c, mode='constant', cval=1234))
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for computing and applying pressure."""
from typing import Callable, Optional
import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# TODO(pnorgaard) Implement bicgstab for non-symmetric operators
def _rhs_transform(
u: grids.GridArray,
bc: boundaries.BoundaryConditions,
) -> Array:
"""Transform the RHS of pressure projection equation for stability.
  For the Poisson equation, the component of the RHS lying in the kernel of
  the Laplacian is subtracted for stability.
  Args:
    u: a GridArray giving the right-hand side of the Poisson problem ∇²x = u.
    bc: the boundary conditions of x.
  Returns:
    u' such that u = u' + (a component in the kernel of the Laplacian).
"""
u_data = u.data
for axis in range(u.grid.ndim):
if bc.types[axis][0] == boundaries.BCType.NEUMANN and bc.types[axis][
1] == boundaries.BCType.NEUMANN:
      # If all sides are Neumann, the Poisson solution has a kernel of constant
      # functions. We subtract the mean to ensure consistency.
u_data = u_data - jnp.mean(u_data)
return u_data
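# Illustrative sketch (not part of the library): with all-Neumann or periodic
# boundaries the discrete Laplacian annihilates constant fields, so the Poisson
# problem is only solvable when the RHS has zero mean. Removing the mean, as
# `_rhs_transform` does above, projects the RHS onto the solvable subspace.
# The helper name below is hypothetical and exists only for illustration.
def _example_rhs_mean_removal():
  laplacian = jnp.array([[-2., 1., 0., 1.],
                         [1., -2., 1., 0.],
                         [0., 1., -2., 1.],
                         [1., 0., 1., -2.]])
  # Constant vectors span the kernel: laplacian @ ones == 0.
  assert jnp.allclose(laplacian @ jnp.ones(4), 0.0)
  rhs = jnp.array([1., 2., 3., 4.])
  rhs_zero_mean = rhs - rhs.mean()
  # A least-squares solve now yields an exact solution of the singular system.
  x, *_ = jnp.linalg.lstsq(laplacian, rhs_zero_mean)
  assert jnp.allclose(laplacian @ x, rhs_zero_mean, atol=1e-5)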
def solve_cg(
v: GridVariableVector,
q0: GridVariable,
pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
rtol: float = 1e-6,
atol: float = 1e-6,
maxiter: Optional[int] = None) -> GridArray:
"""Conjugate gradient solve for the pressure such that continuity is enforced.
Returns a pressure correction `q` such that `div(v - grad(q)) == 0`.
The relationship between `q` and our actual pressure estimate is given by
`p = q * density / dt`.
Args:
v: the velocity field.
q0: an initial value, or "guess" for the pressure correction. A common
choice is the correction from the previous time step. Also specifies the
boundary conditions on `q`.
pressure_bc: the boundary condition to assign to pressure. If None,
      boundary condition is inferred from velocity.
rtol: relative tolerance for convergence.
atol: absolute tolerance for convergence.
maxiter: optional int, the maximum number of iterations to perform.
Returns:
A pressure correction `q` such that `div(v - grad(q))` is zero.
"""
# TODO(jamieas): add functionality for non-uniform density.
rhs = fd.divergence(v)
if pressure_bc is None:
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
def laplacian_with_bcs(array: GridArray) -> GridArray:
variable = pressure_bc.impose_bc(array)
return fd.laplacian(variable)
q, _ = jax.scipy.sparse.linalg.cg(
laplacian_with_bcs,
rhs,
x0=q0.array,
tol=rtol,
atol=atol,
maxiter=maxiter)
return q
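# Illustrative sketch (not part of the library): as noted in the docstring
# above, the correction `q` relates to the physical pressure via
# p = q * density / dt. A hypothetical conversion helper would be:
def _example_pressure_from_correction(q, density, dt):
  return q * density / dt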
def solve_fast_diag(
v: GridVariableVector,
q0: Optional[grids.GridArray] = None,
pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
implementation: Optional[str] = None,
) -> grids.GridArray:
"""Solve for pressure using the fast diagonalization approach.
  For backward compatibility, if pressure_bc is not provided and the velocity
  has all-periodic boundaries, pressure_bc is assigned to be periodic.
Args:
v: a tuple of velocity values for each direction.
q0: the starting guess for the pressure.
pressure_bc: the boundary condition to assign to pressure. If None,
      boundary condition is inferred from velocity.
    implementation: how to implement fast diagonalization.
      For non-periodic BCs this is automatically set to matmul.
Returns:
    A solution to the pressure Poisson equation (PPE).
"""
del q0 # unused
if pressure_bc is None:
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
if boundaries.has_all_periodic_boundary_conditions(*v):
circulant = True
else:
circulant = False
# only matmul implementation supports non-circulant matrices
implementation = 'matmul'
rhs = fd.divergence(v)
laplacians = array_utils.laplacian_matrix_w_boundaries(
rhs.grid, rhs.offset, pressure_bc)
rhs_transformed = _rhs_transform(rhs, pressure_bc)
pinv = fast_diagonalization.pseudoinverse(
laplacians,
rhs_transformed.dtype,
hermitian=True,
circulant=circulant,
implementation=implementation)
return grids.GridArray(pinv(rhs_transformed), rhs.offset, rhs.grid)
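# Illustrative sketch (not part of the library): fast diagonalization works
# because the separable Laplacian used above is a Kronecker sum of 1D operators,
# L = L_x ⊗ I + I ⊗ L_y, so it can be applied (and inverted) axis by axis in
# the eigenbasis of the 1D matrices instead of forming the full 2D matrix. The
# helper below is hypothetical and only demonstrates the structure.
def _example_kronecker_sum_structure():
  def periodic_laplacian_1d(n):
    m = -2. * jnp.eye(n) + jnp.eye(n, k=1) + jnp.eye(n, k=-1)
    return m.at[0, -1].set(1.).at[-1, 0].set(1.)  # periodic wrap-around
  lx, ly = periodic_laplacian_1d(4), periodic_laplacian_1d(6)
  full = jnp.kron(lx, jnp.eye(6)) + jnp.kron(jnp.eye(4), ly)
  x = jnp.arange(24.).reshape(4, 6)
  # Applying the 1D operators axis by axis matches the full Kronecker sum.
  axis_wise = lx @ x + x @ ly.T
  assert jnp.allclose(full @ x.ravel(), axis_wise.ravel(), atol=1e-4)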
def solve_fast_diag_channel_flow(
v: GridVariableVector,
q0: Optional[grids.GridArray] = None,
pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
) -> grids.GridArray:
"""Applies solve_fast_diag for channel flow.
Args:
v: a tuple of velocity values for each direction.
q0: the starting guess for the pressure.
pressure_bc: the boundary condition to assign to pressure. If None,
      boundary condition is inferred from velocity.
  Returns:
    A solution to the pressure Poisson equation (PPE).
"""
if pressure_bc is None:
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
return solve_fast_diag(v, q0, pressure_bc, implementation='matmul')
def projection(
v: GridVariableVector,
solve: Callable = solve_fast_diag,
) -> GridVariableVector:
"""Apply pressure projection to make a velocity field divergence free."""
grid = grids.consistent_grid(*v)
pressure_bc = boundaries.get_pressure_bc_from_velocity(v)
q0 = grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid)
q0 = pressure_bc.impose_bc(q0)
q = solve(v, q0, pressure_bc)
q = pressure_bc.impose_bc(q)
q_grad = fd.forward_difference(q)
v_projected = tuple(
u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad))
return v_projected
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.pressure."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import pressure
from jax_cfd.base import test_util
import numpy as np
USE_FLOAT64 = True
solve_cg = functools.partial(pressure.solve_cg, atol=1e-6, maxiter=10**5)
class PressureTest(test_util.TestCase):
def setUp(self):
jax.config.update('jax_enable_x64', USE_FLOAT64)
super(PressureTest, self).setUp()
def poisson_setup(self, bc, offset):
rs = np.random.RandomState(0)
b = rs.randn(4, 4).astype(np.float32)
grid = grids.Grid((4, 4), domain=((0, 4), (0, 4))) # has step = 1.0
b = grids.GridArray(b, offset, grid)
a = array_utils.laplacian_matrix_w_boundaries(grid, offset, bc)
b_transformed = pressure._rhs_transform(
bc.trim_boundary(b), bc)
a_inv = fast_diagonalization.pseudoinverse(
a, b.dtype, hermitian=True, circulant=False, implementation='matmul')
x = a_inv(b_transformed)
x = grids.GridArray(x, b.offset, grid)
# laplacian is defined only on the interior
x = grids.GridVariable(x, bc)
x = fd.laplacian(x).data
return x, b
@parameterized.named_parameters(
dict(testcase_name='_1D_cg',
shape=(301,),
solve=solve_cg,
step=(.1,),
seed=111),
dict(testcase_name='_2D_cg',
shape=(100, 100),
solve=solve_cg,
step=(1., 1.),
seed=222),
dict(testcase_name='_3D_cg',
shape=(10, 10, 10),
solve=solve_cg,
step=(.1, .1, .1),
seed=333),
dict(testcase_name='_1D_fast_diag',
shape=(301,),
solve=pressure.solve_fast_diag,
step=(.1,),
seed=111),
dict(testcase_name='_2D_fast_diag',
shape=(100, 100),
solve=pressure.solve_fast_diag,
step=(1., 1.),
seed=222),
dict(testcase_name='_3D_fast_diag',
shape=(10, 10, 10),
solve=pressure.solve_fast_diag,
step=(.1, .1, .1),
seed=333),
)
def test_pressure_correction_periodic_bc(
self, shape, solve, step, seed):
"""Returned velocity should be divergence free."""
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
# The uncorrected velocity is a 1 + a small amount of noise in each
# dimension.
ks = jax.random.split(jax.random.PRNGKey(seed), 2 * len(shape))
v = tuple(
grids.GridArray(1. + .3 * jax.random.normal(k, shape), offset, grid)
for k, offset in zip(ks[:grid.ndim], grid.cell_faces))
v = tuple(grids.GridVariable(u, bc) for u in v)
v_corrected = pressure.projection(v, solve)
# The corrected velocity should be divergence free.
div = fd.divergence(v_corrected)
for u, u_corrected in zip(v, v_corrected):
np.testing.assert_allclose(u.offset, u_corrected.offset)
np.testing.assert_allclose(div.data, 0., atol=1e-4)
@parameterized.named_parameters(
dict(testcase_name='_1D_cg',
shape=(10,),
solve=solve_cg,
step=(.1,),
seed=111),
dict(testcase_name='_2D_cg',
shape=(10, 10),
solve=solve_cg,
step=(.1, .1),
seed=222),
dict(testcase_name='_3D_cg',
shape=(10, 10, 10),
solve=solve_cg,
step=(.1, .1, .1),
seed=333),
)
def test_pressure_correction_dirichlet_velocity_bc(
self, shape, solve, step, seed):
"""Returned velocity should be divergence free."""
grid = grids.Grid(shape, step)
velocity_bc = (boundaries.dirichlet_boundary_conditions(
grid.ndim),) * grid.ndim
# The uncorrected velocity is zero + a small amount of noise in each
# dimension.
ks = jax.random.split(jax.random.PRNGKey(seed), 2 * grid.ndim)
v = tuple(
grids.GridArray(0. + .3 * jax.random.normal(k, shape), offset, grid)
for k, offset in zip(ks[:grid.ndim], grid.cell_faces))
# Set boundary velocity to zero
masks = grids.domain_interior_masks(grid)
self.assertLen(masks, grid.ndim)
v = (m * u for m, u in zip(masks, v))
# Associate boundary conditions
v = tuple(grids.GridVariable(u, u_bc) for u, u_bc in zip(v, velocity_bc))
self.assertLen(v, grid.ndim)
# Apply pressure correction
v_corrected = pressure.projection(v, solve)
# The corrected velocity should be divergence free.
div = fd.divergence(v_corrected)
for u, u_corrected in zip(v, v_corrected):
np.testing.assert_allclose(u.offset, u_corrected.offset)
np.testing.assert_allclose(div.data, 0., atol=1e-4)
@parameterized.parameters(
dict(ndim=2, solve=pressure.solve_cg),
dict(ndim=2, solve=pressure.solve_fast_diag_channel_flow),
dict(ndim=3, solve=pressure.solve_cg),
dict(ndim=3, solve=pressure.solve_fast_diag_channel_flow),
)
def test_pressure_correction_mixed_velocity_bc(self, ndim, solve):
"""Returned velocity should be divergence free."""
shape = (20,) * ndim
grid = grids.Grid(shape, step=0.1)
velocity_bc = (boundaries.channel_flow_boundary_conditions(ndim),) * ndim
def rand_array(seed):
key = jax.random.split(jax.random.PRNGKey(seed))
return jax.random.normal(key[0], shape)
v = tuple(
grids.GridArray(
1. + .3 * rand_array(seed=d), offset=grid.cell_faces[d], grid=grid)
for d in range(ndim))
# Associate and enforce boundary conditions
v = tuple(grids.GridVariable(u, u_bc).impose_bc()
for u, u_bc in zip(v, velocity_bc))
    # y-velocity = 0 for the edge y=y_max (homogeneous Dirichlet BC)
# y-velocity on lower y-boundary is not on an edge
# Note, x- and z-velocity do not have an edge value on the y-boundaries
self.assertAllClose(v[1].data[:, -1, ...], 0)
# Apply pressure correction
v_corrected = pressure.projection(v, solve)
# The corrected velocity should be divergence free.
div = fd.divergence(v_corrected)
for u, u_corrected in zip(v, v_corrected):
np.testing.assert_allclose(u.offset, u_corrected.offset)
np.testing.assert_allclose(div.data, 0., atol=1e-4)
@parameterized.parameters(((1.0, 0.5),), ((1.0, 1.0),), ((1.0, 0.0),))
def test_poisson_periodic_and_dirichlet(self, offset):
bc = boundaries.periodic_and_dirichlet_boundary_conditions()
x, b = self.poisson_setup(bc, offset)
self.assertAllClose(x, bc.trim_boundary(b).data, atol=1e-5)
@parameterized.parameters(((1.0, 0.5),), ((0.5, 0.5),))
def test_poisson_periodic_and_neumann(self, offset):
bc = boundaries.periodic_and_neumann_boundary_conditions()
x, b = self.poisson_setup(bc, offset)
self.assertAllClose(x, b.data - b.data.mean(), atol=1e-5)
@parameterized.parameters(((1.0, 0.5),), ((1.0, 1.0),), ((1.0, 0.0),))
def test_poisson_2d_periodic(self, offset):
bc = boundaries.periodic_boundary_conditions(2)
x, b = self.poisson_setup(bc, offset)
self.assertAllClose(x, b.data - b.data.mean(), atol=1e-5)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Resize velocity fields to a different resolution grid."""
from typing import Optional, Tuple, Union
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils as arr_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import interpolation
import numpy as np
Array = grids.Array
Field = Tuple[Array, ...]
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
RawArray = jnp.ndarray
def downsample_staggered_velocity_component(u: Array, direction: int,
factor: int) -> Array:
"""Downsamples `u`, an array of velocities in the given `direction`.
Downsampling consists of the following steps:
* Establish new downsampled control volumes. Each will consist of
`factor ** dimension` of the fine-grained control volumes.
* Discard all of the `u` values that do not lie on a face of the new control
volume in `direction`.
* Compute the mean of all `u` values that lie on each control volume face in
the given `direction`.
This procedure guarantees that if our source velocity has zero divergence
(i.e., corresponds to an incompressible flow), the downsampled velocity field
also has zero divergence.
For example,
```
u = [[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
w = downsample_velocity(u, direction=0, factor=2)
assert w == np.array([[4.5, 6.5],
[12.5, 14.5]])
```
Args:
u: an array of velocity values.
direction: an integer indicating the direction of the velocities `u`.
factor: the factor by which to downsample.
Returns:
Coarse-grained array, reduced in size by ``factor`` along each dimension.
"""
w = arr_utils.slice_along_axis(u, direction, slice(factor - 1, None, factor))
block_size = tuple(1 if j == direction else factor for j in range(u.ndim))
return arr_utils.block_reduce(w, block_size, jnp.mean)
def top_hat_downsample(
source_grid: grids.Grid,
destination_grid: grids.Grid,
variables: GridVariableVector,
filter_size: Optional[Union[int, Tuple[int, ...]]] = None
) -> GridVariableVector:
"""Filters each variable by filter_size and subsamples onto destination_grid.
Downsampling consists of the following steps:
* Filter the data by averaging
* Interpolate the averaged data onto the destination_grid
This procedure corresponds to standard top-hat filter + comb downsampling.
Note that the filter size does not have to equal the factor difference between
the two grids. The intended use case is for filter size >= factor.
Args:
    source_grid: the grid of the input variables. Note: this is a legacy
      argument; each variables[i] is a GridVariable and already has a grid
      attribute.
destination_grid: the grid on which to interpolate filtered variables.
variables: a tuple of GridVariables. Note that the grid attribute of each
variable has to agree with source_grid.
filter_size: the number of grid points used in the filter. If it's an int,
it specifies the same number of points to filter in all directions. If it
      is a tuple, each direction is specified separately.
Returns:
    A tuple of GridVariables interpolated onto destination_grid.
"""
# assumes different filtering can be done in different directions
factor = tuple(
dx / dx_source
for dx, dx_source in zip(destination_grid.step, source_grid.step))
if filter_size is None:
filter_size = factor
if isinstance(filter_size, int):
filter_size = tuple(filter_size for _ in range(source_grid.ndim))
assert destination_grid.domain == source_grid.domain
assert all([round(f) == f for f in factor])
assert all([round(f) == f for f in filter_size]) # this can be relaxed
acceptable_filter = lambda f: f % 2 == 0 or f == 1
assert all(map(acceptable_filter,
filter_size)) # only even filters are implemented
assert all(list(map(acceptable_filter,
factor))) # only even factors are implemented
# filter has to be at least as large as the factor.
assert all(filt >= f for f, filt in zip(factor, filter_size))
result = []
for c in variables:
if c.grid != source_grid:
raise grids.InconsistentGridError(
f'source_grid for downsampling is {source_grid}, but c is defined'
f' on {c.grid}')
bc = c.bc
offset = c.offset
center_offset = tuple(
0.5 if f > 1 else o for o, f in zip(offset, filter_size))
c_centered = interpolation.linear(c, center_offset).array
center_offset = np.array(center_offset)
grid_shape = np.array(source_grid.shape)
for axis in range(c.grid.ndim):
c_centered = bc.pad(
c_centered,
round(filter_size[axis]) // 2,
axis=axis,
mode=boundaries.Padding.MIRROR)
c_centered = bc.pad(
c_centered,
-(round(filter_size[axis]) // 2),
axis=axis,
mode=boundaries.Padding.MIRROR)
convolution_filter = jnp.ones(round(
filter_size[axis])) / filter_size[axis]
convolve_1d = lambda arr, convolution_filter=convolution_filter: jnp.convolve( # pylint: disable=g-long-lambda
arr, convolution_filter, 'valid')
axes = list(range(source_grid.ndim))
axes.remove(axis)
for ax in axes:
convolve_1d = jax.vmap(convolve_1d, in_axes=ax, out_axes=ax)
c_centered = convolve_1d(c_centered.data)
if filter_size[axis] > 1:
if np.isclose(offset[axis], 0):
start = 0
end = c_centered.shape[axis] - 1
elif np.isclose(offset[axis], 0.5):
start = int(factor[axis]) // 2
end = None
elif np.isclose(offset[axis], 1.0):
start = int(factor[axis])
end = None
else:
raise NotImplementedError(f'offset {offset} is not implemented.')
else:
start = 0
end = None
c_centered = arr_utils.slice_along_axis(
c_centered, axis, slice(start, end, int(factor[axis])))
center_offset[axis] = offset[axis]
grid_shape[axis] = destination_grid.shape[axis]
c_centered = grids.GridArray(
c_centered,
offset=tuple(center_offset),
grid=grids.Grid(shape=tuple(grid_shape), domain=source_grid.domain))
c = grids.GridVariable(c_centered, bc)
result.append(c)
return tuple(result)
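# Illustrative sketch (not part of the library): for cell-centered data on a
# periodic 1D grid with filter_size equal to the coarsening factor, the
# filter-then-subsample procedure above reduces to block averaging, e.g.
# [0, 1, 2, 3] -> [0.5, 2.5] for a factor of 2 (cf. the periodic cases in
# resize_test.py). The helper name below is hypothetical.
def _example_block_average_1d(data: np.ndarray, factor: int) -> np.ndarray:
  return data.reshape(-1, factor).mean(axis=-1)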
def downsample_staggered_velocity(
source_grid: grids.Grid,
destination_grid: grids.Grid,
velocity: Union[Field, GridArrayVector, GridVariableVector],
):
"""Downsamples each component of `v` by `factor`."""
factor = destination_grid.step[0] / source_grid.step[0]
assert destination_grid.domain == source_grid.domain
assert round(factor) == factor, factor
result = []
for j, u in enumerate(velocity):
if isinstance(u, GridVariable):
def downsample(u: GridVariable, direction: int,
factor: int) -> GridVariable:
if u.grid != source_grid:
raise grids.InconsistentGridError(
f'source_grid for downsampling is {source_grid}, but u is defined'
f' on {u.grid}')
array = downsample_staggered_velocity_component(u.data, direction,
round(factor))
grid_array = GridArray(array, offset=u.offset, grid=destination_grid)
return GridVariable(grid_array, bc=u.bc)
elif isinstance(u, GridArray):
def downsample(u: GridArray, direction: int, factor: int) -> GridArray:
if u.grid != source_grid:
raise grids.InconsistentGridError(
f'source_grid for downsampling is {source_grid}, but u is defined'
f' on {u.grid}')
array = downsample_staggered_velocity_component(u.data, direction,
round(factor))
return GridArray(array, offset=u.offset, grid=destination_grid)
else:
downsample = downsample_staggered_velocity_component
result.append(downsample(u, j, round(factor)))
return tuple(result)
# TODO(dresdner) gin usage should be restricted to jax_cfd.ml
def downsample_spectral(_: grids.Grid, destination_grid: grids.Grid,
signal_hat: RawArray):
"""Downsamples a 2D signal in the Fourier basis to the `destination_grid`."""
kx, ky = destination_grid.rfft_axes()
(num_x,), (num_y,) = kx.shape, ky.shape
input_num_x, _ = signal_hat.shape
downed = jnp.concatenate(
[signal_hat[:num_x // 2, :num_y], signal_hat[-num_x // 2:, :num_y]])
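  # jnp.fft.rfftn is unnormalized, so Fourier coefficients scale with the
  # number of grid points; truncating from input_num_x to num_x points per axis
  # is compensated by the (num_x / input_num_x) ** 2 factor below (assuming the
  # same downsampling ratio in both directions).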
scale = (num_x / input_num_x)
downed *= scale**2
return downed
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.base.resize."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences
from jax_cfd.base import grids
from jax_cfd.base import initial_conditions
from jax_cfd.base import resize
from jax_cfd.base import test_util
import numpy as np
BCType = boundaries.BCType
def periodic_grid_variable(data, offset, grid):
return grids.GridVariable(
array=grids.GridArray(data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
class ResizeTest(test_util.TestCase):
@parameterized.parameters(
dict(u=np.array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]),
direction=0,
factor=2,
expected=np.array([[4.5, 6.5],
[12.5, 14.5]])),
dict(u=np.array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]),
direction=1,
factor=2,
expected=np.array([[3., 5.],
[11., 13.]])),
)
def testDownsampleVelocityComponent(self, u, direction, factor, expected):
"""Test `downsample_array` produces the expected results."""
actual = resize.downsample_staggered_velocity_component(
u, direction, factor)
self.assertAllClose(expected, actual)
def testDownsampleVelocity(self):
source_grid = grids.Grid((4, 4), domain=[(0, 1), (0, 1)])
destination_grid = grids.Grid((2, 2), domain=[(0, 1), (0, 1)])
u = np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
])
expected = (np.array([[4.5, 6.5],
[12.5, 14.5]]), np.array([[3., 5.], [11., 13.]]))
with self.subTest('ArrayField'):
velocity = (u, u)
actual = resize.downsample_staggered_velocity(source_grid,
destination_grid, velocity)
self.assertAllClose(expected, actual)
with self.subTest('GridArrayVector'):
velocity = (grids.GridArray(u, offset=(1, 0), grid=source_grid),
grids.GridArray(u, offset=(0, 1), grid=source_grid))
actual = resize.downsample_staggered_velocity(source_grid,
destination_grid, velocity)
expected_combined = (
grids.GridArray(expected[0], offset=(1, 0), grid=destination_grid),
grids.GridArray(expected[1], offset=(0, 1), grid=destination_grid))
self.assertAllClose(expected_combined[0], actual[0])
self.assertAllClose(expected_combined[1], actual[1])
with self.subTest('GridArrayVector: Inconsistent Grids'):
with self.assertRaisesRegex(grids.InconsistentGridError,
'source_grid for downsampling'):
different_grid = grids.Grid((4, 4), domain=[(-2, 2), (0, 1)])
velocity = (grids.GridArray(u, offset=(1, 0), grid=different_grid),
grids.GridArray(u, offset=(0, 1), grid=different_grid))
resize.downsample_staggered_velocity(source_grid,
destination_grid, velocity)
with self.subTest('GridVariableVector'):
velocity = (periodic_grid_variable(u, (1, 0), source_grid),
periodic_grid_variable(u, (0, 1), source_grid))
actual = resize.downsample_staggered_velocity(source_grid,
destination_grid, velocity)
expected_combined = (
periodic_grid_variable(expected[0], (1, 0), destination_grid),
periodic_grid_variable(expected[1], (0, 1), destination_grid),
)
self.assertAllClose(expected_combined[0], actual[0])
self.assertAllClose(expected_combined[1], actual[1])
with self.subTest('GridVariableVector: Inconsistent Grids'):
with self.assertRaisesRegex(grids.InconsistentGridError,
'source_grid for downsampling'):
different_grid = grids.Grid((4, 4), domain=[(-2, 2), (0, 1)])
velocity = (
periodic_grid_variable(u, (1, 0), different_grid),
periodic_grid_variable(u, (0, 1), different_grid))
resize.downsample_staggered_velocity(source_grid,
destination_grid, velocity)
def testDownsampleFourierVorticity(self):
with self.subTest('Space2D'):
domain = ((0, 2 * jnp.pi), (0, 2 * jnp.pi))
fine = grids.Grid(((256, 256)), domain=domain)
medium = grids.Grid(((128, 128)), domain=domain)
coarse = grids.Grid(((64, 64)), domain=domain)
v0 = initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(42), fine, maximum_velocity=7, peak_wavenumber=1)
fine_signal = finite_differences.curl_2d(v0).data
# test that fine -> medium -> coarse == fine -> coarse
fine_signal_hat = jnp.fft.rfftn(fine_signal)
fine_to_medium = resize.downsample_spectral(
None, medium, fine_signal_hat)
medium_to_coarse = resize.downsample_spectral(
None, coarse, fine_to_medium)
fine_to_coarse = resize.downsample_spectral(
None, coarse, fine_signal_hat)
self.assertAllClose(fine_to_coarse, medium_to_coarse)
# test that grid -> grid does nothing
self.assertAllClose(
fine_signal_hat,
resize.downsample_spectral(None, fine, fine_signal_hat))
self.assertAllClose(
fine_to_medium,
resize.downsample_spectral(None, medium, fine_to_medium))
with self.subTest('DownsampleWavenumbers'):
fine = grids.Grid((256, 256), domain=((0, 2 * jnp.pi),) * 2)
coarse = grids.Grid((64, 64), domain=((0, 2 * jnp.pi),) * 2)
kx_fine, ky_fine = fine.rfft_mesh()
kx_coarse, ky_coarse = coarse.rfft_mesh()
# (256/64)^2 = 16
kx_down = 16 * resize.downsample_spectral(None, coarse, kx_fine)
ky_down = 16 * resize.downsample_spectral(None, coarse, ky_fine)
self.assertArrayEqual(kx_down, kx_coarse)
self.assertArrayEqual(ky_down, ky_coarse)
class ResizeTopHatTest(test_util.TestCase):
@parameterized.parameters(
# Periodic BC
dict(
input_data=np.array([0, 1, 2, 3]),
input_offset=(.5,),
expected_data=np.array([0.5, 2.5]),
filter_size=None,
destination_grid_shape=2,
),
dict(
input_data=np.array([-1, 1, 2, 3]),
input_offset=(1.,),
expected_data=np.array([.75, 1.75]),
filter_size=None,
destination_grid_shape=2,
),
dict(
input_data=np.array([-1, 1, 2, 3]),
input_offset=(0.,),
expected_data=np.array([
.5,
2.,
]),
filter_size=None,
destination_grid_shape=2,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(1.,),
expected_data=np.array([1.75, 3., 2.75]),
filter_size=4,
destination_grid_shape=3,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(.5,),
expected_data=np.array([2., 2.5, 3.]),
filter_size=4,
destination_grid_shape=3,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(0.,),
expected_data=np.array([2.25, 2, 3.25]),
filter_size=4,
destination_grid_shape=3,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(1.,),
expected_data=np.array([3., 4.]),
filter_size=4,
destination_grid_shape=2,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(.5,),
expected_data=np.array([1.5, 5.5]),
filter_size=4,
destination_grid_shape=2,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(0.,),
expected_data=np.array([3., 4.]),
filter_size=4,
destination_grid_shape=2,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(1.,),
expected_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
filter_size=None,
destination_grid_shape=8,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(.5,),
expected_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
filter_size=None,
destination_grid_shape=8,
),
dict(
input_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
input_offset=(0.,),
expected_data=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
filter_size=None,
destination_grid_shape=8,
),
)
def test_downsample_velocity_1d_grid(self, input_data, input_offset,
expected_data, filter_size,
destination_grid_shape):
source_grid = grids.Grid(
input_data.shape, domain=[
(0, 1),
])
destination_grid = grids.Grid((destination_grid_shape,), domain=[
(0, 1),
])
v = grids.GridVariable(
grids.GridArray(input_data, input_offset, source_grid),
boundaries.periodic_boundary_conditions(1))
(actual,) = resize.top_hat_downsample(source_grid, destination_grid, (v,),
filter_size)
expected = grids.GridVariable(
grids.GridArray(expected_data, input_offset, grid=destination_grid),
boundaries.periodic_boundary_conditions(1))
self.assertAllClose(expected, actual)
@parameterized.parameters(
# Dirichlet BC
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 0.0),)),
input_data=np.array([1, 2, 3, 4, 5, 0]),
input_offset=(1.,),
expected_data=np.array([2, 3.25, 0]),
filter_size=4,
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 0.0),)),
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(.5,),
expected_data=np.array([.75, 2.5, 1.75]),
filter_size=4,
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 0.0),)),
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(0.,),
expected_data=np.array([0, 2, 3.25]),
filter_size=4,
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 1.0),)),
input_data=np.array([1, 2, 3, 4, 5, 1]),
input_offset=(1.,),
expected_data=np.array([2, 3.375, 1]),
filter_size=4,
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 1.0),)),
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(.5,),
expected_data=np.array([.75, 2.5, 2.25]),
filter_size=4,
),
dict(
bc_types=(((BCType.DIRICHLET, BCType.DIRICHLET),), ((0.0, 1.0),)),
input_data=np.array([0, 1, 2, 3, 4, 5]),
input_offset=(0.,),
expected_data=np.array([0, 2, 3.375]),
filter_size=4,
),
)
def test_downsample_velocity_1d_2x_grid_dirichlet(self, bc_types, input_data,
input_offset, expected_data,
filter_size):
source_grid = grids.Grid(
input_data.shape, domain=[
(0, 1),
])
destination_grid = grids.Grid(
np.array(input_data.shape) // 2, domain=[
(0, 1),
])
v = grids.GridVariable(
grids.GridArray(input_data, input_offset, source_grid),
boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1]))
actual = resize.top_hat_downsample(source_grid, destination_grid, (v,),
filter_size)
expected = (grids.GridVariable(
grids.GridArray(expected_data, input_offset, grid=destination_grid),
boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])),)
self.assertAllClose(expected[0], actual[0])
@parameterized.parameters(
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.PERIODIC, BCType.PERIODIC)), ((None, None),
(None, None))),
input_data=np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]),
input_offset=(0.5, 1.),
expected_data=np.array([[3, 4], [11, 12]]),
filter_size=None,
),
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)), ((None, None),
(0.0, 1.0))),
input_data=np.array([
[0, 2, 3, 1],
[4, 5, 6, 1],
[8, 9, 10, 1],
[12, 13, 14, 1],
]),
input_offset=(0.5, 1.),
expected_data=np.array([[3.375, 1], [11, 1]]),
filter_size=None,
),
dict(
bc_types=(((BCType.PERIODIC, BCType.PERIODIC),
(BCType.DIRICHLET, BCType.DIRICHLET)), ((None, None),
(0.0, 1.0))),
input_data=np.array([
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]),
input_offset=(1.0, .5),
expected_data=np.array([[4.5, 6.5], [8.5, 10.5]]),
filter_size=None,
),
)
def test_downsample_velocity_2d_2x_grid(self, bc_types, input_data,
input_offset, expected_data,
filter_size):
source_grid = grids.Grid(input_data.shape, domain=[(0, 1), (0, 1)])
destination_grid = grids.Grid(
np.array(input_data.shape) // 2, domain=[(0, 1), (0, 1)])
v = grids.GridVariable(
grids.GridArray(input_data, input_offset, source_grid),
boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1]))
actual = resize.top_hat_downsample(source_grid, destination_grid, (v,),
filter_size)
expected = (grids.GridVariable(
grids.GridArray(expected_data, input_offset, grid=destination_grid),
boundaries.ConstantBoundaryConditions(bc_types[0], bc_types[1])),)
self.assertAllClose(expected[0], actual[0])
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for subgrid models."""
import functools
from typing import Any, Callable, Mapping, Optional
import jax
from jax_cfd.base import boundaries
from jax_cfd.base import equations
from jax_cfd.base import finite_differences
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import interpolation
import numpy as np
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = interpolation.InterpolationFn
ViscosityFn = Callable[[grids.GridArrayTensor, GridVariableVector],
grids.GridArrayTensor]
# TODO(pnorgaard) Refactor subgrid_models to interpolate, then differentiate
def smagorinsky_viscosity(
s_ij: grids.GridArrayTensor,
v: GridVariableVector,
dt: Optional[float] = None,
cs: float = 0.2,
interpolate_fn: InterpolationFn = interpolation.linear
) -> grids.GridArrayTensor:
"""Computes eddy viscosity based on Smagorinsky model.
This viscosity model computes scalar eddy viscosity at `grid.cell_center` and
  then interpolates it to offsets of the strain rate tensor `s_ij`. Based on:
https://en.wikipedia.org/wiki/Large_eddy_simulation#Smagorinsky-Lilly_model
Args:
s_ij: strain rate tensor that is equal to the forward finite difference
derivatives of the velocity field `(d(u_i)/d(x_j) + d(u_j)/d(x_i)) / 2`.
v: velocity field, passed to `interpolate_fn`.
dt: integration time step passed to `interpolate_fn`. Can be `None` if
`interpolate_fn` is independent of `dt`. Default: `None`.
cs: the Smagorinsky constant.
interpolate_fn: interpolation method to use for viscosity interpolations.
Returns:
tensor of GridArray's containing values of the eddy viscosity at the
same grid offsets as the strain tensor `s_ij`.
"""
# Present implementation:
# - s_ij is a GridArrayTensor
# - v is converted to a GridVariableVector
  # - interpolation method is wrapped so that the interpolated quantity is a
# GridArray (rather than GridVariable), using periodic BC.
#
# This should be revised so that s_ij is computed by first interpolating
# velocity and then computing s_ij via finite differences, producing
# a `GridVariableTensor`. Then no wrapper or GridArray/GridVariable
# conversion hacks are needed.
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('smagorinsky_viscosity only valid for periodic BC.')
bc = grids.unique_boundary_conditions(*v)
def wrapped_interp_fn(c, offset, v, dt):
return interpolate_fn(grids.GridVariable(c, bc), offset, v, dt).array
grid = grids.consistent_grid(*s_ij.ravel(), *v)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
s_ij_offsets = [array.offset for array in s_ij.ravel()]
unique_offsets = list(set(s_ij_offsets))
cell_center = grid.cell_center
interpolate_to_center = lambda x: wrapped_interp_fn(x, cell_center, v, dt)
centered_s_ij = np.vectorize(interpolate_to_center)(s_ij)
# geometric average
cutoff = np.prod(np.array(grid.step))**(1 / grid.ndim)
viscosity = (cs * cutoff)**2 * np.sqrt(
2 * np.trace(centered_s_ij.dot(centered_s_ij)))
viscosities_dict = {
offset: wrapped_interp_fn(viscosity, offset, v, dt).data
for offset in unique_offsets}
viscosities = [viscosities_dict[offset] for offset in s_ij_offsets]
  return jax.tree_util.tree_unflatten(
      jax.tree_util.tree_structure(s_ij), viscosities)
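# Illustrative sketch (not part of the library): the scalar Smagorinsky
# viscosity computed above is nu_t = (cs * delta)**2 * sqrt(2 * S_ij S_ij),
# with delta the geometric mean of the grid spacing. The hypothetical helper
# below evaluates the same formula for a plain strain-rate array of shape
# [ndim, ndim, ...] instead of a GridArrayTensor.
def _example_scalar_smagorinsky(s, step, cs=0.2):
  delta = np.prod(np.array(step)) ** (1 / len(step))
  strain_norm = np.sqrt(2 * np.sum(s * s, axis=(0, 1)))
  return (cs * delta) ** 2 * strain_norm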
def evm_model(
v: GridVariableVector,
viscosity_fn: ViscosityFn,
) -> GridArrayVector:
"""Computes acceleration due to eddy viscosity turbulence model.
Eddy viscosity models compute a turbulence closure term as a divergence of
the subgrid-scale stress tensor, which is expressed as velocity dependent
viscosity times the rate of strain tensor. This module delegates computation
of the eddy-viscosity to `viscosity_fn` function.
Args:
v: velocity field.
viscosity_fn: function that computes viscosity values at the same offsets
as strain rate tensor provided as input.
Returns:
acceleration of the velocity field `v`.
"""
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError('evm_model only valid for periodic BC.')
grid = grids.consistent_grid(*v)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
s_ij = grids.GridArrayTensor([
[0.5 * (finite_differences.forward_difference(v[i], j) + # pylint: disable=g-complex-comprehension
finite_differences.forward_difference(v[j], i))
for j in range(grid.ndim)]
for i in range(grid.ndim)])
viscosity = viscosity_fn(s_ij, v)
tau = jax.tree_util.tree_map(lambda x, y: -2. * x * y, viscosity, s_ij)
return tuple(-finite_differences.divergence( # pylint: disable=g-complex-comprehension
      tuple(grids.GridVariable(t, bc)  # use velocity bc to compute divergence
for t in tau[i, :]))
for i in range(grid.ndim))
# TODO(dkochkov) remove when b/160947162 is resolved.
def implicit_evm_solve_with_diffusion(
v: GridVariableVector,
viscosity: float,
dt: float,
configured_evm_model: Callable, # pylint: disable=g-bare-generic
cg_kwargs: Optional[Mapping[str, Any]] = None
) -> GridVariableVector:
"""Implicit solve for eddy viscosity model combined with diffusion.
This method is intended to be used with `implicit_diffusion_navier_stokes` to
avoid potential numerical instabilities associated with fast diffusion modes.
Args:
v: current velocity field.
viscosity: constant viscosity coefficient.
dt: time step of implicit integration.
configured_evm_model: eddy viscosity model with specified `viscosity_fn`.
cg_kwargs: keyword arguments passed to jax.scipy.sparse.linalg.cg.
Returns:
velocity field advanced in time by `dt`.
"""
if cg_kwargs is None:
cg_kwargs = {}
cg_kwargs = dict(cg_kwargs)
cg_kwargs.setdefault('tol', 1e-6)
cg_kwargs.setdefault('atol', 1e-6)
if not boundaries.has_all_periodic_boundary_conditions(*v):
raise ValueError(
'implicit_evm_solve_with_diffusion only valid for periodic BC.')
bc = grids.unique_boundary_conditions(*v)
vector_laplacian = np.vectorize(finite_differences.laplacian)
  # Note: `velocity` below is a tuple of GridArrays, distinct from
  # the arg v from the outer function.
def linear_op(velocity):
v_var = tuple(grids.GridVariable(u, bc) for u in velocity)
acceleration = configured_evm_model(v_var)
    return tuple(
        u - dt * (a + viscosity * l)
        for u, a, l in zip(velocity, acceleration, vector_laplacian(v_var)))
# We normally prefer fast diagonalization, but that requires an outer
# product structure for the linear operation, which doesn't hold here.
# TODO(shoyer): consider adding a preconditioner
v_prime, _ = jax.scipy.sparse.linalg.cg(linear_op, tuple(u.array for u in v),
**cg_kwargs)
return tuple(
grids.GridVariable(u_prime, u.bc) for u_prime, u in zip(v_prime, v))
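# Illustrative sketch (not part of the library): the CG solve above is a
# backward-Euler step, i.e. it solves (I - dt * A) v_next = v where A combines
# the eddy-viscosity acceleration and the diffusion operator. A hypothetical
# scalar analogue, with A reduced to a constant rate `a`:
def _example_backward_euler_scalar(v, a, dt):
  return v / (1 - dt * a)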
def explicit_smagorinsky_navier_stokes(dt, cs, forcing, **kwargs):
"""Constructs explicit navier-stokes model with Smagorinsky viscosity term.
Navier-Stokes model that uses explicit time stepping for the eddy viscosity
model based on Smagorinsky closure term.
Args:
dt: time step to be performed.
cs: smagorinsky constant.
forcing: forcing term.
**kwargs: other keyword arguments to be passed to
`equations.semi_implicit_navier_stokes`.
Returns:
A function that performs a single step of time evolution of navier-stokes
equations with Smagorinsky turbulence model.
"""
viscosity_fn = functools.partial(
smagorinsky_viscosity, dt=dt, cs=cs)
smagorinsky_acceleration = functools.partial(
evm_model, viscosity_fn=viscosity_fn)
if forcing is None:
forcing = smagorinsky_acceleration
else:
forcing = forcings.sum_forcings(forcing, smagorinsky_acceleration)
return equations.semi_implicit_navier_stokes(dt=dt, forcing=forcing, **kwargs)
def implicit_smagorinsky_navier_stokes(dt, cs, **kwargs):
"""Constructs implicit navier-stokes model with Smagorinsky viscosity term.
  Navier-Stokes model that uses implicit time stepping for the eddy viscosity
model based on Smagorinsky closure term. The implicit step is performed using
conjugate gradients and is combined with diffusion solve.
Args:
dt: time step to be performed.
cs: smagorinsky constant.
**kwargs: other keyword arguments to be passed to
`equations.implicit_diffusion_navier_stokes`.
Returns:
A function that performs a single step of time evolution of navier-stokes
equations with Smagorinsky turbulence model.
"""
viscosity_fn = functools.partial(
smagorinsky_viscosity, dt=dt, cs=cs)
smagorinsky_acceleration = functools.partial(
evm_model, viscosity_fn=viscosity_fn)
diffusion_solve_with_evm = functools.partial(
implicit_evm_solve_with_diffusion,
configured_evm_model=smagorinsky_acceleration)
return equations.implicit_diffusion_navier_stokes(
diffusion_solve=diffusion_solve_with_evm, dt=dt, **kwargs)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.subgrid_models."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import pressure
from jax_cfd.base import subgrid_models
from jax_cfd.base import test_util
import numpy as np
def periodic_grid_variable(data, offset, grid):
return grids.GridVariable(
array=grids.GridArray(data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
def zero_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns an all-zero periodic velocity fields."""
return tuple(periodic_grid_variable(jnp.zeros(grid.shape), o, grid)
for o in grid.cell_faces)
def sinusoidal_velocity_field(grid: grids.Grid) -> grids.GridVariableVector:
"""Returns a divergence-free velocity flow on `grid`."""
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
vs = tuple(jnp.sin(2. * np.pi * g / s)
for g, s in zip(grid.mesh(), mesh_size))
return tuple(periodic_grid_variable(v, o, grid)
for v, o in zip(vs[1:] + vs[:1], grid.cell_faces))
def gaussian_force_field(grid: grids.Grid) -> grids.GridArrayVector:
"""Returns a 'Gaussian-shaped' force field in the 'x' direction."""
mesh = grid.mesh()
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
offsets = grid.cell_faces
v = [grids.GridArray(
jnp.exp(-sum([jnp.square(x / s - .5)
for x, s in zip(mesh, mesh_size)]) * 100.),
offsets[0], grid)]
for j in range(1, grid.ndim):
v.append(grids.GridArray(jnp.zeros(grid.shape), offsets[j], grid))
return tuple(v)
def gaussian_forcing(v: grids.GridVariableVector) -> grids.GridArrayVector:
"""Returns Gaussian field forcing."""
grid = grids.consistent_grid(*v)
return gaussian_force_field(grid)
def momentum(v: grids.GridVariableVector, density: float):
"""Returns the momentum due to velocity field `v`."""
grid = grids.consistent_grid(*v)
return jnp.array([u.data for u in v]).sum() * density * jnp.array(
grid.step).prod()
def _convect_upwind(v: grids.GridVariableVector) -> grids.GridArrayVector:
return tuple(advection.advect_upwind(u, v) for u in v)
class SubgridModelsTest(test_util.TestCase):
def test_smagorinsky_viscosity(self):
grid = grids.Grid((3, 3))
v = (periodic_grid_variable(jnp.zeros(grid.shape), (1, 0.5), grid),
periodic_grid_variable(jnp.zeros(grid.shape), (0.5, 1), grid))
c00 = grids.GridArray(jnp.zeros(grid.shape), offset=(0, 0), grid=grid)
c01 = grids.GridArray(jnp.zeros(grid.shape), offset=(0, 1), grid=grid)
c10 = grids.GridArray(jnp.zeros(grid.shape), offset=(1, 0), grid=grid)
c11 = grids.GridArray(jnp.zeros(grid.shape), offset=(1, 1), grid=grid)
s_ij = grids.GridArrayTensor(np.array([[c00, c01], [c10, c11]]))
viscosity = subgrid_models.smagorinsky_viscosity(
s_ij=s_ij, v=v, dt=0.1, cs=0.2)
self.assertIsInstance(viscosity, grids.GridArrayTensor)
self.assertEqual(viscosity.shape, (2, 2))
self.assertAllClose(viscosity[0, 0], c00)
self.assertAllClose(viscosity[0, 1], c01)
self.assertAllClose(viscosity[1, 0], c10)
self.assertAllClose(viscosity[1, 1], c11)
def test_evm_model(self):
grid = grids.Grid((3, 3))
v = (
periodic_grid_variable(jnp.zeros(grid.shape), (1, 0.5), grid),
periodic_grid_variable(jnp.zeros(grid.shape), (0.5, 1), grid))
viscosity_fn = functools.partial(
subgrid_models.smagorinsky_viscosity, dt=1.0, cs=0.2)
acceleration = subgrid_models.evm_model(v, viscosity_fn)
self.assertIsInstance(acceleration, tuple)
self.assertLen(acceleration, 2)
self.assertAllClose(acceleration[0], v[0].array)
self.assertAllClose(acceleration[1], v[1].array)
@parameterized.named_parameters(
dict(
testcase_name='sinusoidal_velocity_base',
cs=0.0,
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=advection.convect_linear,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=2e-3),
dict(
testcase_name='gaussian_force_upwind_with_subgrid_model',
cs=0.12,
velocity=zero_velocity_field,
forcing=gaussian_forcing,
shape=(40, 40, 40),
step=(1., 1., 1.),
density=1.,
viscosity=0,
convect=_convect_upwind,
pressure_solve=pressure.solve_cg,
dt=1e-3,
time_steps=100,
divergence_atol=1e-4,
momentum_atol=1e-4),
dict(
testcase_name='sinusoidal_velocity_with_subgrid_model',
cs=0.12,
velocity=sinusoidal_velocity_field,
forcing=None,
shape=(100, 100),
step=(1., 1.),
density=1.,
viscosity=1e-4,
convect=advection.convect_linear,
pressure_solve=pressure.solve_fast_diag,
dt=1e-3,
time_steps=1000,
divergence_atol=1e-3,
momentum_atol=1e-3),
)
def test_divergence_and_momentum(
self,
cs,
velocity,
forcing,
shape,
step,
density,
viscosity,
convect,
pressure_solve,
dt,
time_steps,
divergence_atol,
momentum_atol,
):
grid = grids.Grid(shape, step)
kwargs = dict(
density=density,
viscosity=viscosity,
cs=cs,
dt=dt,
grid=grid,
convect=convect,
pressure_solve=pressure_solve,
forcing=forcing)
# Explicit and implicit navier-stokes solvers:
explicit_eq = subgrid_models.explicit_smagorinsky_navier_stokes(**kwargs)
implicit_eq = subgrid_models.implicit_smagorinsky_navier_stokes(**kwargs)
v_initial = velocity(grid)
v_final = funcutils.repeated(explicit_eq, time_steps)(v_initial)
# TODO(dkochkov) consider adding more thorough tests for these models.
with self.subTest('divergence free'):
divergence = fd.divergence(v_final)
self.assertLess(jnp.max(divergence.data), divergence_atol)
with self.subTest('conservation of momentum'):
initial_momentum = momentum(v_initial, density)
final_momentum = momentum(v_final, density)
if forcing is not None:
expected_change = (
jnp.array([f.data for f in forcing(v_initial)]).sum() *
jnp.array(grid.step).prod() * dt * time_steps)
else:
expected_change = 0
expected_momentum = initial_momentum + expected_change
self.assertAllClose(expected_momentum, final_momentum, atol=momentum_atol)
with self.subTest('explicit-implicit consistency'):
v_final_2 = funcutils.repeated(implicit_eq, time_steps)(v_initial)
for axis in range(grid.ndim):
self.assertAllClose(v_final[axis], v_final_2[axis], atol=1e-4,
err_msg=f'axis={axis}')
if __name__ == '__main__':
absltest.main()