Commit 5e31fa1f authored by mashun1's avatar mashun1
Browse files

jax-cfd

parents
Pipeline #1015 canceled with stages
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared test utilities."""
from absl.testing import parameterized
from jax import config
from jax_cfd.base import grids
import numpy as np
config.parse_flags_with_absl()
class TestCase(parameterized.TestCase):
"""TestCase with assertions for arrays and grids.AlignedArray."""
def _check_and_remove_alignment_and_grid(self, *arrays):
"""Check that array-like data values and other attributes match.
If args type is GridArray, verify their offsets and grids match.
If args type is GridVariable, verify their offsets, grids, and bc match.
Args:
*arrays: one or more Array, GridArray or GridVariable, but they all be the
same type.
Returns:
The data-only arrays, with other attributes removed.
"""
is_gridarray = [isinstance(array, grids.GridArray) for array in arrays]
# GridArray
if any(is_gridarray):
self.assertTrue(
all(is_gridarray), msg=f'arrays have mixed types: {arrays}')
try:
grids.consistent_offset(*arrays)
except grids.InconsistentOffsetError as e:
raise AssertionError(str(e)) from None
try:
grids.consistent_grid(*arrays)
except grids.InconsistentGridError as e:
raise AssertionError(str(e)) from None
arrays = tuple(array.data for array in arrays)
# GridVariable
is_gridvariable = [
isinstance(array, grids.GridVariable) for array in arrays
]
if any(is_gridvariable):
self.assertTrue(
all(is_gridvariable), msg=f'arrays have mixed types: {arrays}')
try:
grids.consistent_offset(*arrays)
except grids.InconsistentOffsetError as e:
raise AssertionError(str(e)) from None
try:
grids.consistent_grid(*arrays)
except grids.InconsistentGridError as e:
raise AssertionError(str(e)) from None
try:
grids.unique_boundary_conditions(*arrays)
except grids.InconsistentBoundaryConditionsError as e:
raise AssertionError(str(e)) from None
arrays = tuple(array.array.data for array in arrays)
return arrays
# pylint: disable=unbalanced-tuple-unpacking
def assertArrayEqual(self, expected, actual, **kwargs):
expected, actual = self._check_and_remove_alignment_and_grid(
expected, actual)
np.testing.assert_array_equal(expected, actual, **kwargs)
def assertAllClose(self, expected, actual, **kwargs):
expected, actual = self._check_and_remove_alignment_and_grid(
expected, actual)
np.testing.assert_allclose(expected, actual, **kwargs)
# pylint: enable=unbalanced-tuple-unpacking
\ No newline at end of file
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Time stepping for Navier-Stokes equations."""
import dataclasses
from typing import Callable, Sequence, TypeVar
import jax
import tree_math
PyTreeState = TypeVar("PyTreeState")
TimeStepFn = Callable[[PyTreeState], PyTreeState]
class ExplicitNavierStokesODE:
"""Spatially discretized version of Navier-Stokes.
The equation is given by:
∂u/∂t = explicit_terms(u)
0 = incompressibility_constraint(u)
"""
def __init__(self, explicit_terms, pressure_projection):
self.explicit_terms = explicit_terms
self.pressure_projection = pressure_projection
def explicit_terms(self, state):
"""Explicitly evaluate the ODE."""
raise NotImplementedError
def pressure_projection(self, state):
"""Enforce the incompressibility constraint."""
raise NotImplementedError
@dataclasses.dataclass
class ButcherTableau:
a: Sequence[Sequence[float]]
b: Sequence[float]
# TODO(shoyer): add c, when we support time-dependent equations.
def __post_init__(self):
if len(self.a) + 1 != len(self.b):
raise ValueError("inconsistent Butcher tableau")
def navier_stokes_rk(
tableau: ButcherTableau,
equation: ExplicitNavierStokesODE,
time_step: float,
) -> TimeStepFn:
"""Create a forward Runge-Kutta time-stepper for incompressible Navier-Stokes.
This function implements the reference method (equations 16-21), rather than
the fast projection method, from:
"Fast-Projection Methods for the Incompressible Navier–Stokes Equations"
Fluids 2020, 5, 222; doi:10.3390/fluids5040222
Args:
tableau: Butcher tableau.
equation: equation to use.
time_step: overall time-step size.
Returns:
Function that advances one time-step forward.
"""
# pylint: disable=invalid-name
dt = time_step
F = tree_math.unwrap(equation.explicit_terms)
P = tree_math.unwrap(equation.pressure_projection)
a = tableau.a
b = tableau.b
num_steps = len(b)
@tree_math.wrap
def step_fn(u0):
u = [None] * num_steps
k = [None] * num_steps
u[0] = u0
k[0] = F(u0)
for i in range(1, num_steps):
u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j])
u[i] = P(u_star)
k[i] = F(u[i])
u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j])
u_final = P(u_star)
return u_final
return step_fn
def forward_euler(
equation: ExplicitNavierStokesODE, time_step: float,
) -> TimeStepFn:
return jax.named_call(
navier_stokes_rk(
ButcherTableau(a=[], b=[1]),
equation,
time_step),
name="forward_euler",
)
def midpoint_rk2(
equation: ExplicitNavierStokesODE, time_step: float,
) -> TimeStepFn:
return jax.named_call(
navier_stokes_rk(
ButcherTableau(a=[[1/2]], b=[0, 1]),
equation=equation,
time_step=time_step,
),
name="midpoint_rk2",
)
def heun_rk2(
equation: ExplicitNavierStokesODE, time_step: float,
) -> TimeStepFn:
return jax.named_call(
navier_stokes_rk(
ButcherTableau(a=[[1]], b=[1/2, 1/2]),
equation=equation,
time_step=time_step,
),
name="heun_rk2",
)
def classic_rk4(
equation: ExplicitNavierStokesODE, time_step: float,
) -> TimeStepFn:
return jax.named_call(
navier_stokes_rk(
ButcherTableau(a=[[1/2], [0, 1/2], [0, 0, 1]],
b=[1/6, 1/3, 1/3, 1/6]),
equation=equation,
time_step=time_step,
),
name="classic_rk4",
)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for time_stepping."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import config
import jax.numpy as jnp
from jax_cfd.base import funcutils
from jax_cfd.base import time_stepping
import numpy as np
def harmonic_oscillator(x0, t):
theta = jnp.arctan(x0[0] / x0[1])
r = jnp.linalg.norm(x0, ord=2, axis=0)
return r * jnp.stack([jnp.sin(t + theta), jnp.cos(t + theta)])
ALL_TEST_PROBLEMS = [
dict(testcase_name='_harmonic_oscillator_explicit',
explicit_terms=lambda x: jnp.stack([x[1], -x[0]]),
pressure_projection=lambda x: x,
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.ones(2),
closed_form=harmonic_oscillator,
tolerances=[1e-2, 3e-5, 3e-5, 4e-7]),
]
ALL_TIME_STEPPERS = [
time_stepping.forward_euler,
time_stepping.midpoint_rk2,
time_stepping.heun_rk2,
time_stepping.classic_rk4,
]
class TimeSteppingTest(parameterized.TestCase):
@parameterized.named_parameters(ALL_TEST_PROBLEMS)
def test_integration(
self,
explicit_terms,
pressure_projection,
dt,
inner_steps,
outer_steps,
initial_state,
closed_form,
tolerances,
):
# Compute closed-form solution.
time = dt * inner_steps * (1 + np.arange(outer_steps))
expected = jax.vmap(closed_form, in_axes=(None, 0))(
initial_state, time)
# Compute trajectory using time-stepper.
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
with self.subTest(time_stepper.__name__):
equation = time_stepping.ExplicitNavierStokesODE(
explicit_terms, pressure_projection)
step_fn = time_stepper(equation, dt)
integrator = funcutils.trajectory(
funcutils.repeated(step_fn, inner_steps), outer_steps)
_, actual = integrator(initial_state)
np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
if __name__ == '__main__':
config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Descriptions and analytical solutions for validation problems."""
import abc
from typing import Optional, Sequence, Tuple
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariableVector = grids.GridVariableVector
Offsets = Sequence[Sequence[float]]
class Problem(metaclass=abc.ABCMeta):
"""An abstract class for Navier-Stokes problems."""
@property
def grid(self):
return self._grid # pytype: disable=attribute-error # bind-properties
@property
def density(self):
return self._density # pytype: disable=attribute-error # bind-properties
@property
def viscosity(self):
return self._viscosity # pytype: disable=attribute-error # bind-properties
def force(self,
offsets: Optional[Offsets] = None) -> Optional[GridArrayVector]:
del offsets # Unused
return None
@abc.abstractmethod
def velocity(self,
t: float,
offsets: Optional[Offsets] = None) -> GridVariableVector:
pass
class TaylorGreen(Problem):
"""2D Taylor Green vortices with analytic solution for velocity.
See https://en.wikipedia.org/wiki/Taylor%E2%80%93Green_vortex.
"""
# TODO(jamieas): consider parameterizing problems in terms of Reynolds
# number.
def __init__(self,
shape: Tuple[int, int],
density: float = 1,
viscosity: float = 0,
kx: float = 1,
ky: float = 1):
self._grid = grids.Grid(shape=shape,
domain=[(0., 2. * np.pi),
(0., 2. * np.pi)])
self._density = density
self._viscosity = viscosity
self._kx = kx
self._ky = ky
def velocity(
self,
t: float = 0,
offsets: Optional[Offsets] = None) -> GridVariableVector:
"""Returns an analytic solution for velocity at time `t`."""
if offsets is None:
offsets = self.grid.cell_faces
scale = jnp.exp(-2 * self.viscosity * t)
ux, uy = self.grid.mesh(offsets[0])
u = grids.GridVariable(
array=grids.GridArray(
data=scale * jnp.cos(self._kx * ux) * jnp.sin(self._ky * uy),
offset=offsets[0],
grid=self.grid),
bc=boundaries.periodic_boundary_conditions(self.grid.ndim))
vx, vy = self.grid.mesh(offsets[1])
v = grids.GridVariable(
array=grids.GridArray(
data=-scale * jnp.sin(self._kx * vx) * jnp.cos(self._ky * vy),
offset=offsets[1],
grid=self.grid),
bc=boundaries.periodic_boundary_conditions(self.grid.ndim))
return (u, v)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Validation tests for JAX-CFD."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp
from jax_cfd.base import advection
from jax_cfd.base import diffusion
from jax_cfd.base import equations
from jax_cfd.base import funcutils
from jax_cfd.base import test_util
from jax_cfd.base import time_stepping
from jax_cfd.base import validation_problems
class ValidationTests(test_util.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='_TaylorGreen_SemiImplicitNavierStokes',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
solver=functools.partial(
equations.semi_implicit_navier_stokes,
convect=advection.convect_linear),
implicit_diffusion=False,
max_courant_number=.1,
time=10.,
atol=1e-5),
dict(
testcase_name='_TaylorGreen_SemiImplicitNavierStokes_rk1',
problem=validation_problems.TaylorGreen(
shape=(512, 512), density=1., viscosity=1e-2),
solver=functools.partial(
equations.semi_implicit_navier_stokes,
convect=advection.convect_linear,
time_stepper=time_stepping.forward_euler,
),
implicit_diffusion=False,
max_courant_number=.1,
time=40.,
atol=6e-6),
dict(
testcase_name='_TaylorGreen_SemiImplicitNavierStokes_rk4',
problem=validation_problems.TaylorGreen(
shape=(512, 512), density=1., viscosity=1e-2),
solver=functools.partial(
equations.semi_implicit_navier_stokes,
convect=advection.convect_linear,
time_stepper=time_stepping.classic_rk4,
),
implicit_diffusion=False,
max_courant_number=.1,
time=40.,
atol=8e-7),
dict(
testcase_name='_TaylorGreen_ImplicitDiffusionNavierStokes_matmul',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
solver=functools.partial(
equations.implicit_diffusion_navier_stokes,
convect=advection.convect_linear,
diffusion_solve=functools.partial(
diffusion.solve_fast_diag, implementation='matmul'),
),
implicit_diffusion=True,
max_courant_number=.1,
time=10.,
atol=3e-5),
dict(
testcase_name='_TaylorGreen_ImplicitDiffusionNavierStokes_fft',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
solver=functools.partial(
equations.implicit_diffusion_navier_stokes,
convect=advection.convect_linear,
diffusion_solve=functools.partial(
diffusion.solve_fast_diag, implementation='fft'),
),
implicit_diffusion=True,
max_courant_number=.1,
time=10.,
atol=4e-5),
dict(
testcase_name='_TaylorGreen_ImplicitDiffusionNavierStokes_rfft',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
solver=functools.partial(
equations.implicit_diffusion_navier_stokes,
convect=advection.convect_linear,
diffusion_solve=functools.partial(
diffusion.solve_fast_diag, implementation='rfft'),
),
implicit_diffusion=True,
max_courant_number=.1,
time=10.,
atol=4e-5),
dict(
testcase_name='_TaylorGreen_ImplicitDiffusionNavierStokes_cg',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
solver=functools.partial(
equations.implicit_diffusion_navier_stokes,
convect=advection.convect_linear,
diffusion_solve=functools.partial(
diffusion.solve_cg, atol=1e-6, maxiter=512)),
implicit_diffusion=True,
max_courant_number=.1,
time=10.,
atol=3e-5),
dict(
testcase_name='_TaylorGreen_ImplicitDiffusionNavierStokes_viscous',
problem=validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=0.5),
solver=functools.partial(
equations.implicit_diffusion_navier_stokes,
convect=advection.convect_linear),
implicit_diffusion=True,
max_courant_number=.5,
time=1.0,
atol=6e-4,
),
)
def test_accuracy(self, problem, solver, implicit_diffusion,
max_courant_number, time, atol):
"""Test the accuracy of `solver` on the given `problem`.
Args:
problem: an instance of `validation_problems.Problem`.
solver: a callable that takes `density`, `viscosity`, `dt`, `grid`, and
`steps`. It returns a callable that takes `velocity`,
`pressure_correction` and `force` and returns updated versions of these
values at the next time step.
implicit_diffusion: whether or not the solver models diffusion implicitly.
max_courant_number: a float used to choose the size of the time step `dt`
according to the Courant-Friedrichs-Lewy condition. See
https://en.wikipedia.org/wiki/Courant-Friedrichs-Lewy_condition.
time: the amount of time to run the simulation for.
atol: absolute error tolerance per entry.
"""
v = problem.velocity(0.)
dt = equations.dynamic_time_step(
v, max_courant_number, problem.viscosity, problem.grid,
implicit_diffusion)
steps = int(jnp.ceil(time / dt))
navier_stokes = solver(density=problem.density,
viscosity=problem.viscosity,
dt=dt,
grid=problem.grid)
v_computed = funcutils.repeated(navier_stokes, steps)(v)
v_analytic = problem.velocity(time)
for u_c, u_a in zip(v_computed, v_analytic):
self.assertAllClose(u_c, u_a, atol=atol)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Collocated grid versions of "base" physics routines for JAX-CFD."""
import jax_cfd.collocated.advection
import jax_cfd.collocated.diffusion
import jax_cfd.collocated.equations
import jax_cfd.collocated.initial_conditions
import jax_cfd.collocated.pressure
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for functionality related to advection."""
from typing import Optional, Tuple
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
def advect_linear(c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None) -> GridArray:
"""Computes advection for collocated scalar `c` with velocity `v`."""
del dt
flux_bc = [
boundaries.get_advection_flux_bc_from_velocity_and_scalar(u, c, direction)
for direction, u in enumerate(v)
]
flux = tuple(flux_bc[axis].impose_bc(c.array * v[axis].array)
for axis in range(c.grid.ndim))
return -fd.centered_divergence(flux)
def _velocities_to_flux(v: GridVariableVector) -> Tuple[GridVariableVector]:
"""Computes the cell-centered convective flux for a velocity field.
This is the flux associated with the nonlinear term `vv` for velocity `v`.
The boundary condition on the flux is inherited from `v`.
Args:
v: velocity vector.
Returns:
A tuple of tuples `flux` of `GridVariable`s with the values `v[i]*v[j]`
"""
ndim = len(v)
flux = [tuple() for _ in range(ndim)]
ndim = len(v)
flux = [tuple() for _ in range(ndim)]
for i in range(ndim):
for j in range(ndim):
if i <= j:
bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar(
v[j], v[i], j)
flux[i] += (bc.impose_bc(v[i].array * v[j].array),)
else:
flux[i] += (flux[j][i],)
return tuple(flux)
def convect_linear(v: GridVariableVector) -> GridArrayVector:
"""Computes convection/self-advection of the velocity field `v`.
Args:
v: velocity vector.
Returns:
A tuple containing the time derivative of each component of `v` due to
convection.
"""
fluxes = _velocities_to_flux(v)
return tuple(-fd.centered_divergence(flux) for flux in fluxes)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax_cfd.advection."""
import functools
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.collocated import advection
import numpy as np
def _cos_velocity(grid):
offset = grid.cell_center
mesh = grid.mesh(offset)
mesh_size = jnp.array(grid.shape) * jnp.array(grid.step)
v = tuple(grids.GridArray(jnp.cos(2. * np.pi * x / s), offset, grid)
for x, s in zip(mesh, mesh_size))
return v
def _euler_step(advection_method):
def step(c, v, dt):
c_new = c.array + dt * advection_method(c, v, dt)
return c.bc.impose_bc(c_new)
return step
class AdvectionTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_linear_1D',
shape=(101,),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear),
dict(testcase_name='_linear_3D',
shape=(101, 101, 101),
advection_method=advection.advect_linear,
convection_method=advection.convect_linear)
)
def test_convection_vs_advection(
self, shape, advection_method, convection_method,
):
"""Exercises self-advection, check equality with advection on components."""
step = tuple(1. / s for s in shape)
grid = grids.Grid(shape, step)
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in _cos_velocity(grid))
self_advected = convection_method(v)
for u, du in zip(v, self_advected):
advected_component = advection_method(u, v)
self.assertAllClose(advected_component, du)
@parameterized.named_parameters(
dict(
testcase_name='dichlet_advect',
shape=(101,),
method=_euler_step(advection.advect_linear)),)
def test_mass_conservation_dirichlet(self, shape, method):
cfl_number = 0.1
dt = cfl_number / shape[0]
num_steps = 1000
grid = grids.Grid(shape, domain=([-1., 1.],))
bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),))
def u(grid):
x = grid.mesh((0.5,))[0]
return grids.GridArray(-jnp.sin(jnp.pi * x), (0.5,), grid)
def c0(grid):
x = grid.mesh((0.5,))[0]
return grids.GridArray(x, (0.5,), grid)
v = (bc.impose_bc(u(grid)),)
c = c_bc.impose_bc(c0(grid))
ct = c
advect = jax.jit(functools.partial(method, v=v, dt=dt))
initial_mass = np.sum(c.data)
for _ in range(num_steps):
ct = advect(ct)
current_total_mass = np.sum(ct.data)
self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)
@parameterized.named_parameters(
dict(
testcase_name='linear_1d_neumann',
shape=(1000,),
method=advection.advect_linear),)
def test_neumann_bc_one_step(self, shape, method):
grid = grids.Grid(shape, domain=([-1., 1.],))
bc = boundaries.neumann_boundary_conditions(grid.ndim)
c_bc = boundaries.neumann_boundary_conditions(grid.ndim)
def u(grid):
x = grid.mesh((0.5,))[0]
return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid)
def c0(grid):
x = grid.mesh((0.5,))[0]
return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid)
def dcdt(grid):
x = grid.mesh((0.5,))[0]
return grids.GridArray(jnp.pi * jnp.sin(2 * jnp.pi * x), (0.5,), grid)
v = (bc.impose_bc(u(grid)),)
c = c_bc.impose_bc(c0(grid))
advect = jax.jit(functools.partial(method, v=v))
ct = advect(c)
self.assertAllClose(ct, dcdt(grid), atol=1e-4)
if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(pnorgaard) Implement bicgstab for non-symmetric operators
"""Module for functionality related to diffusion."""
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
GridArray = grids.GridArray
GridVariable = grids.GridVariable
# TODO(pnorgaard): Implement the equivalent expanded 5-point laplacian operator
def diffuse(c: GridVariable, nu: float) -> GridArray:
"""Returns the rate of change in a concentration `c` due to diffusion."""
if not boundaries.has_all_periodic_boundary_conditions(c):
raise ValueError('Expected periodic BC')
gradient = fd.central_difference(c, axis=None)
gradient = tuple(grids.GridVariable(g, c.bc) for g in gradient)
return nu * fd.centered_divergence(gradient)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment