Commit 5e31fa1f authored by mashun1's avatar mashun1
Browse files

jax-cfd

parents
Pipeline #1015 canceled with stages
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pseudospectral equations."""
import dataclasses
from typing import Callable, Optional
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.spectral import forcings as spectral_forcings
from jax_cfd.spectral import time_stepping
from jax_cfd.spectral import types
from jax_cfd.spectral import utils as spectral_utils
TimeDependentForcingFn = Callable[[float], types.Array]
RandomSeed = int
ForcingModule = Callable[[grids.Grid, RandomSeed], TimeDependentForcingFn]
@dataclasses.dataclass
class KuramotoSivashinsky(time_stepping.ImplicitExplicitODE):
"""Kuramoto–Sivashinsky (KS) equation split in implicit and explicit parts.
The KS equation is
u_t = - u_xx - u_xxxx - 1/2 * (u ** 2)_x
Implicit parts are the linear terms and explicit parts are the non-linear
terms.
Attributes:
grid: underlying grid of the process
smooth: smooth the non-linear term using the 3/2-rule
"""
grid: grids.Grid
smooth: bool = True
def __post_init__(self):
self.kx, = self.grid.rfft_axes()
self.two_pi_i_k = 2j * jnp.pi * self.kx
self.linear_term = -self.two_pi_i_k ** 2 - self.two_pi_i_k ** 4
self.rfft = spectral_utils.truncated_rfft if self.smooth else jnp.fft.rfft
self.irfft = spectral_utils.padded_irfft if self.smooth else jnp.fft.irfft
def explicit_terms(self, uhat):
"""Non-linear parts of the equation, namely `- 1/2 * (u ** 2)_x`."""
uhat_squared = self.rfft(jnp.square(self.irfft(uhat)))
return -0.5 * self.two_pi_i_k * uhat_squared
def implicit_terms(self, uhat):
"""Linear parts of the equation, namely `- u_xx - u_xxxx`."""
return self.linear_term * uhat
def implicit_solve(self, uhat, time_step):
"""Solves for `implicit_terms`, implicitly."""
# TODO(dresdner) the same for all linear terms. generalize/refactor?
return 1 / (1 - time_step * self.linear_term) * uhat
@dataclasses.dataclass
class ForcedBurgersEquation(time_stepping.ImplicitExplicitODE):
"""Burgers' Equation with the option to add a time-dependent forcing function."""
viscosity: float
grid: grids.Grid
seed: int = 0
forcing_module: Optional[
ForcingModule] = spectral_forcings.random_forcing_module
_forcing_fn = None
def __post_init__(self):
self.kx, = self.grid.rfft_axes()
self.two_pi_i_k = 2j * jnp.pi * self.kx
self.linear_term = self.viscosity * self.two_pi_i_k ** 2
self.rfft = spectral_utils.truncated_rfft
self.irfft = spectral_utils.padded_irfft
if self.forcing_module is None:
self._forcing_fn = lambda t: jnp.zeros(1)
else:
self._forcing_fn = self.forcing_module(self.grid, self.seed)
def explicit_terms(self, state):
uhat, t = state
dudx = self.two_pi_i_k * uhat
f = self._forcing_fn(t)
fhat = jnp.fft.rfft(f)
advection = - self.rfft(self.irfft(uhat) * self.irfft(dudx))
return (fhat + advection, 1.0)
def implicit_terms(self, state):
uhat, _ = state
return (self.linear_term * uhat, 0.0)
def implicit_solve(self, state, time_step):
uhat, t = state
return (1 / (1 - time_step * self.linear_term) * uhat, t)
def BurgersEquation(viscosity: float, grid: grids.Grid, seed: int = 0):
"""Standard, unforced Burgers' equation."""
return ForcedBurgersEquation(
viscosity=viscosity, grid=grid, seed=seed, forcing_module=None)
# pylint: disable=invalid-name
def _get_grid_variable(arr,
grid,
bc=boundaries.periodic_boundary_conditions(2),
offset=(0.5, 0.5)):
return grids.GridVariable(grids.GridArray(arr, offset, grid), bc)
@dataclasses.dataclass
class NavierStokes2D(time_stepping.ImplicitExplicitODE):
"""Breaks the Navier-Stokes equation into implicit and explicit parts.
Implicit parts are the linear terms and explicit parts are the non-linear
terms.
Attributes:
viscosity: strength of the diffusion term
grid: underlying grid of the process
smooth: smooth the advection term using the 2/3-rule.
forcing_fn: forcing function, if None then no forcing is used.
drag: strength of the drag. Set to zero for no drag.
"""
viscosity: float
grid: grids.Grid
drag: float = 0.
smooth: bool = True
forcing_fn: Optional[Callable[[grids.Grid], forcings.ForcingFn]] = None
_forcing_fn_with_grid = None
def __post_init__(self):
self.kx, self.ky = self.grid.rfft_mesh()
self.laplace = (jnp.pi * 2j)**2 * (self.kx**2 + self.ky**2)
self.filter_ = spectral_utils.brick_wall_filter_2d(self.grid)
self.linear_term = self.viscosity * self.laplace - self.drag
# setup the forcing function with the caller-specified grid.
if self.forcing_fn is not None:
self._forcing_fn_with_grid = self.forcing_fn(self.grid)
def explicit_terms(self, vorticity_hat):
velocity_solve = spectral_utils.vorticity_to_velocity(self.grid)
vxhat, vyhat = velocity_solve(vorticity_hat)
vx, vy = jnp.fft.irfftn(vxhat), jnp.fft.irfftn(vyhat)
grad_x_hat = 2j * jnp.pi * self.kx * vorticity_hat
grad_y_hat = 2j * jnp.pi * self.ky * vorticity_hat
grad_x, grad_y = jnp.fft.irfftn(grad_x_hat), jnp.fft.irfftn(grad_y_hat)
advection = -(grad_x * vx + grad_y * vy)
advection_hat = jnp.fft.rfftn(advection)
if self.smooth is not None:
advection_hat *= self.filter_
terms = advection_hat
if self.forcing_fn is not None:
fx, fy = self._forcing_fn_with_grid((_get_grid_variable(vx, self.grid),
_get_grid_variable(vy, self.grid)))
fx_hat, fy_hat = jnp.fft.rfft2(fx.data), jnp.fft.rfft2(fy.data)
terms += spectral_utils.spectral_curl_2d((self.kx, self.ky),
(fx_hat, fy_hat))
return terms
def implicit_terms(self, vorticity_hat):
return self.linear_term * vorticity_hat
def implicit_solve(self, vorticity_hat, time_step):
return 1 / (1 - time_step * self.linear_term) * vorticity_hat
# pylint: disable=g-doc-args,g-doc-return-or-yield,invalid-name
def ForcedNavierStokes2D(viscosity, grid, smooth):
"""Sets up the flow that is used in Kochkov et al. [1].
The authors of [1] based their work on Boffetta et al. [2].
References:
[1] Machine learning–accelerated computational fluid dynamics. Dmitrii
Kochkov, Jamie A. Smith, Ayya Alieva, Qing Wang, Michael P. Brenner, Stephan
Hoyer Proceedings of the National Academy of Sciences May 2021, 118 (21)
e2101784118; DOI: 10.1073/pnas.2101784118.
https://doi.org/10.1073/pnas.2101784118
[2] Boffetta, Guido, and Robert E. Ecke. "Two-dimensional turbulence."
Annual review of fluid mechanics 44 (2012): 427-451.
https://doi.org/10.1146/annurev-fluid-120710-101240
"""
wave_number = 4
offsets = ((0, 0), (0, 0))
# pylint: disable=g-long-lambda
forcing_fn = lambda grid: forcings.kolmogorov_forcing(
grid, k=wave_number, offsets=offsets)
return NavierStokes2D(
viscosity,
grid,
drag=0.1,
smooth=smooth,
forcing_fn=forcing_fn)
@dataclasses.dataclass
class NonlinearSchrodinger(time_stepping.ImplicitExplicitODE):
"""Nonlinear schrodinger equation split in implicit and explicit parts.
The NLS equation is
`psi_t = -i psi_xx/8 - i|psi|^2 psi/2`
Attributes:
grid: underlying grid of the process
smooth: smooth the non-linear by upsampling 2x in fourier and truncating
"""
grid: grids.Grid
smooth: bool = True
def __post_init__(self):
self.kx, = self.grid.fft_axes()
assert len(self.kx) % 2 == 0, "Odd grid sizes not supported, try N even"
self.two_pi_i_k = 2j * jnp.pi * self.kx
self.fft = spectral_utils.truncated_fft_2x if self.smooth else jnp.fft.fft
self.ifft = spectral_utils.padded_ifft_2x if self.smooth else jnp.fft.ifft
def explicit_terms(self, psihat):
"""Non-linear part of the equation `-i|psi|^2 psi/2`."""
psi = self.ifft(psihat)
ipsi_cubed = 1j * psi * jnp.abs(psi)**2
ipsi_cubed_hat = self.fft(ipsi_cubed)
return -ipsi_cubed_hat / 2
def implicit_terms(self, psihat):
"""The diffusion term `-i psi_xx/8` to be handled implicitly."""
return -1j * psihat * self.two_pi_i_k**2 / 8
def implicit_solve(self, psihat, time_step):
"""Solves for `implicit_terms`, implicitly."""
return psihat / (1 - time_step * (-1j * self.two_pi_i_k**2 / 8))
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for spectral equations."""
from typing import Tuple
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
import jax_cfd.base as cfd
from jax_cfd.base import finite_differences
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.spectral import equations as spectral_equations
from jax_cfd.spectral import time_stepping
ALL_TIME_STEPPERS = [
time_stepping.backward_forward_euler,
time_stepping.crank_nicolson_rk2,
time_stepping.crank_nicolson_rk3,
time_stepping.crank_nicolson_rk4,
]
ALL_TIME_STEPPERS = [
dict(testcase_name='_' + s.__name__, time_stepper=s)
for s in ALL_TIME_STEPPERS
]
def roll(arr, offset: Tuple[int]):
"""Rolls an n-dim arr by offset."""
assert len(offset) == len(arr.shape)
for i, o in enumerate(offset):
arr = jnp.roll(arr, o, axis=i)
return arr
def get_grid(resolution, domain=(0, 2*jnp.pi)):
return grids.Grid((resolution,), domain=(domain,))
def get_zeros_initial_condition(grid, dtype=jnp.complex64):
n, = grid.shape
return jnp.zeros(n // 2 + 1, dtype=dtype)
def get_sine_initial_condition(grid):
xs, = grid.axes(offset=(0,))
return jnp.fft.rfft(jnp.sin(xs))
class EquationsTest1D(test_util.TestCase):
def test_ks_equation(self):
"""Test that the KS equation (1) does not explode and (2) conserves momentum."""
size = 128
outer_steps = 2100
length = 10. * jnp.pi
grid = cfd.grids.Grid((size,), domain=((0, length),))
dx, = grid.step
dt = dx / length
# TODO(dresdner) make a parameterized test
for smooth in [True, False]:
step_fn = time_stepping.backward_forward_euler(
spectral_equations.KuramotoSivashinsky(grid, smooth=smooth), dt)
rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
xs, = grid.axes()
v0 = jnp.cos((1 / length) * xs)
v0 = jnp.fft.rfft(v0)
_, trajectory = jax.device_get(rollout_fn(v0))
real_space_trajectory = jnp.fft.irfft(trajectory).real
# ensure no explosion
self.assertTrue(jnp.all(real_space_trajectory < 1e5))
# conservation of momentum: momentum does not change over time
initial_momentum = real_space_trajectory[0].sum()
self.assertAllClose(
initial_momentum, jnp.sum(real_space_trajectory, axis=1), atol=1e-3)
@parameterized.named_parameters(
dict(
testcase_name='one_step_zeros',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_zeros_initial_condition,
num_steps=1,
),
dict(
testcase_name='one_step_sine',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_sine_initial_condition,
num_steps=1),
dict(
testcase_name='many_step_zeros',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_zeros_initial_condition,
num_steps=1000),
dict(
testcase_name='many_step_sine',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_sine_initial_condition,
num_steps=1000),
)
def test_burgers_equation(self, viscosity, grid, time_step,
initial_condition_fn, num_steps):
"""Check that the trajectories don't give NaNs."""
eq = spectral_equations.BurgersEquation(viscosity=viscosity, grid=grid)
step_fn = time_stepping.crank_nicolson_rk2(eq, time_step)
step_fn = cfd.funcutils.repeated(step_fn, num_steps)
uhat0 = initial_condition_fn(grid)
t0 = 0.0
uhat1, _ = step_fn((uhat0, t0))
self.assertFalse(jnp.isnan(uhat1).any())
@parameterized.named_parameters(
dict(
testcase_name='one_step_zeros',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_zeros_initial_condition,
num_steps=1,
),
dict(
testcase_name='many_step_zeros',
viscosity=0.01,
grid=get_grid(128),
time_step=0.01,
initial_condition_fn=get_zeros_initial_condition,
num_steps=1000),
)
def test_forced_burgers_equation(self, viscosity, grid, time_step,
initial_condition_fn, num_steps):
"""Check that the trajectories don't give NaNs."""
eq = spectral_equations.ForcedBurgersEquation(
viscosity=viscosity, grid=grid)
step_fn = time_stepping.crank_nicolson_rk2(eq, time_step)
step_fn = cfd.funcutils.repeated(step_fn, num_steps)
uhat0 = initial_condition_fn(grid)
t0 = 0.0
uhat1, _ = step_fn((uhat0, t0))
self.assertFalse(jnp.isnan(uhat1).any())
def test_nls_equation(self):
"""Check that trajectory matches Peregrine soliton analytic solution.
Soln from https://en.wikipedia.org/wiki/Peregrine_soliton,
however as we implement `psi_t = -i psi_xx/8 - i|psi|^2 psi/2`
rather than `psi_t = +i psi_xx/2 -+i|psi|^2 psi` from the wiki,
the solution needs to be rescaled and conjugated.
"""
def solve_nls(u0, t_final=1., max_samples=1024, dt=1e-2, extent=500):
N = len(u0) # pylint: disable=invalid-name
grid = grids.Grid((N,), domain=((-extent / 2, extent / 2),))
xs, = grid.axes(offset=(0,))
eq = spectral_equations.NonlinearSchrodinger(grid=grid)
stepfn = time_stepping.crank_nicolson_rk4(eq, dt)
uhat0 = jnp.fft.fft(u0)
numsteps = int(t_final / dt)
ds_period = max(numsteps // max_samples, 1)
multistepfn = jax.jit(cfd.funcutils.repeated(stepfn, ds_period))
_, uhat_traj = cfd.funcutils.trajectory(multistepfn, max_samples)(uhat0)
u_traj = jax.vmap(jnp.fft.ifft)(uhat_traj)
timesteps = (1 + jnp.arange(min(max_samples, numsteps))) * dt * ds_period
return u_traj, xs, timesteps
L = 40 * jnp.pi # pylint: disable=invalid-name
grid = grids.Grid((2**10,), domain=((-L / 2, L / 2),))
dt = 3e-4
tau = 8
T = tau * 2 # pylint: disable=invalid-name
xs, = grid.axes(offset=(0,))
zs = xs * jnp.sqrt(2)
u0 = (4 * zs**2 - 3) / (1 + 4 * zs**2)
soln, x_ds, t_ds = solve_nls(u0, T, dt=dt, extent=L)
z_ds = x_ds * jnp.sqrt(2)
tau_ds = t_ds / 2
gt_soln = 1 - 4 * (1 +
2j * tau_ds[:, None]) / (1 + 4 *
(z_ds**2 + tau_ds[:, None]**2))
gt_soln = jnp.conj(gt_soln * jnp.exp(1j * tau_ds[:, None]))
self.assertLess(jnp.abs(soln - gt_soln).mean(), 1e-3)
class EquationsTest2D(test_util.TestCase):
@parameterized.named_parameters(ALL_TIME_STEPPERS)
def test_forced_turbulence(self, time_stepper):
"""Check that forced turbulence runs for 100 steps without blowing up."""
grid = grids.Grid((128, 128), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
v0 = cfd.initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(42), grid, 7, 4)
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity_hat0 = jnp.fft.rfftn(vorticity0)
viscosity = 1e-3
dt = 1e-5
step_fn = time_stepper(
spectral_equations.NavierStokes2D(
viscosity,
grid,
forcing_fn=forcings.kolmogorov_forcing,
drag=0.1), dt)
trajectory_fn = cfd.funcutils.trajectory(step_fn, 100)
_, trajectory = trajectory_fn(vorticity_hat0)
self.assertTrue(jnp.all(~jnp.isnan(trajectory)))
def test_viscosity(self):
"""Test that higher viscosity results in faster decay."""
grid = grids.Grid((128, 128), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
v0 = cfd.initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(42), grid, 7, 4)
vorticity0 = cfd.finite_differences.curl_2d(v0).data
vorticity_hat0 = jnp.fft.rfftn(vorticity0)
norms = []
for viscosity in [1e-3, 1e-1, 1]:
dt = cfd.equations.stable_time_step(
7, .5, viscosity, grid, implicit_diffusion=True)
step_fn = time_stepping.crank_nicolson_rk4(
spectral_equations.NavierStokes2D(
viscosity,
grid,
forcing_fn=forcings.kolmogorov_forcing,
drag=0.1), dt)
trajectory_fn = cfd.funcutils.trajectory(step_fn, 100)
_, trajectory = trajectory_fn(vorticity_hat0)
norms.append(jnp.linalg.norm(trajectory))
# higher viscosity means that you get to zero faster.
self.assertLess(norms[2], norms[1])
self.assertLess(norms[1], norms[0])
@parameterized.named_parameters(
dict(
testcase_name='_TaylorGreen_SemiImplicitNavierStokes',
problem=cfd.validation_problems.TaylorGreen(
shape=(1024, 1024), density=1., viscosity=1e-3),
equation=spectral_equations.NavierStokes2D,
time_stepper=time_stepping.crank_nicolson_rk4,
max_courant_number=.5,
time=.11,
atol=1e-3),)
def test_accuracy(self, problem, equation, time_stepper, max_courant_number,
time, atol):
"""Check numerical accuracy of our solvers to known analytic solutions."""
# This closely emulates a test in jax cfd:
# https://source.corp.google.com/piper///depot/google3/third_party/py/jax_cfd/base/validation_test.py;l=113
v0 = problem.velocity(0.)
vorticity = finite_differences.curl_2d(v0).data
dt = cfd.equations.stable_time_step(
7,
max_courant_number,
problem.viscosity,
problem.grid,
implicit_diffusion=True)
steps = int(jnp.ceil(time / dt))
step_fn = time_stepper(
equation(
viscosity=problem.viscosity,
grid=problem.grid,
forcing_fn=None,
drag=0), dt)
_, vorticity_computed = cfd.funcutils.trajectory(
cfd.funcutils.repeated(step_fn, steps), 1)(
jnp.fft.rfftn(vorticity))
v = problem.velocity(time)
vorticity_analytic = finite_differences.curl_2d(v).data
self.assertAllClose(
jnp.fft.irfftn(vorticity_computed[0]), vorticity_analytic, atol=atol)
@parameterized.named_parameters(
dict(
testcase_name='_decaying_turbulence',
viscosity=1e-2,
cfl_safety_factor=.1,
max_velocity=2.0,
peak_wavenumber=4,
seed=0,
density=1.0,
n_steps=500,
grid_size=512,
is_forced=False,
atol=0.09,
),
dict(
testcase_name='_forced_turbulence',
viscosity=1e-2,
cfl_safety_factor=.1,
max_velocity=2.0,
peak_wavenumber=4,
seed=0,
density=1.0,
n_steps=150,
grid_size=512,
is_forced=True,
atol=0.07,
),
)
def test_compare_to_finite_difference_method(self, viscosity,
cfl_safety_factor, max_velocity,
peak_wavenumber, seed, density,
n_steps, grid_size,
is_forced,
atol):
"""Compare spectral to finite volume."""
grid = cfd.grids.Grid((grid_size, grid_size),
domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
# Construct a random initial velocity.
v0 = cfd.initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(seed), grid, max_velocity)
# Choose a time step.
dt = cfd.equations.stable_time_step(max_velocity, cfl_safety_factor,
viscosity, grid)
if is_forced:
fvm_forcing = forcings.simple_turbulence_forcing(
grid,
constant_magnitude=1,
constant_wavenumber=4,
linear_coefficient=-0.1,
forcing_type='kolmogorov')
eq = spectral_equations.ForcedNavierStokes2D(
viscosity, grid, smooth=True)
else:
fvm_forcing = None
eq = spectral_equations.NavierStokes2D(
viscosity, grid, smooth=True, drag=0, forcing_fn=None)
# use `repeated` since we only compare the final state
fvm_rollout_fn = jax.jit(
cfd.funcutils.repeated(
cfd.equations.semi_implicit_navier_stokes(
density=density,
viscosity=viscosity,
dt=dt,
grid=grid,
forcing=fvm_forcing),
steps=n_steps))
v = fvm_rollout_fn(v0)
final_state_fvm = cfd.finite_differences.curl_2d(v).data
spectral_rollout_fn = jax.jit(
cfd.funcutils.repeated(time_stepping.crank_nicolson_rk4(eq, dt),
steps=n_steps))
final_state_spectral = jnp.fft.irfftn(
spectral_rollout_fn(
jnp.fft.rfftn(
roll(cfd.finite_differences.curl_2d(v0).data, (1, 1)))))
self.assertAllClose(
final_state_fvm, roll(final_state_spectral, (-1, -1)), atol=atol)
if __name__ == '__main__':
absltest.main()
"""Forcing functions for spectral equations."""
import jax
import jax.numpy as jnp
from jax_cfd.base import grids
def random_forcing_module(grid: grids.Grid,
seed: int = 0,
n: int = 20,
offset=(0,)):
"""Implements the forcing described in Bar-Sinai et al. [*].
Args:
grid: grid to use for the x-axis
seed: random seed for computing the random waves
n: number of random waves to use
offset: offset for the x-axis. Defaults to (0,) for the Fourier basis.
Returns:
Time dependent forcing function.
[*] Bar-Sinai, Yohai, Stephan Hoyer, Jason Hickey, and Michael P. Brenner.
"Learning data-driven discretizations for partial differential equations."
Proceedings of the National Academy of Sciences 116, no. 31 (2019):
15344-15349.
"""
key = jax.random.PRNGKey(seed)
ks = jnp.array([3, 4, 5, 6])
key, subkey = jax.random.split(key)
kx = jax.random.choice(subkey, ks, shape=(n,))
key, subkey = jax.random.split(key)
amplitude = jax.random.uniform(subkey, minval=-0.5, maxval=0.5, shape=(n,))
key, subkey = jax.random.split(key)
omega = jax.random.uniform(subkey, minval=-0.4, maxval=0.4, shape=(n,))
key, subkey = jax.random.split(key)
phi = jax.random.uniform(subkey, minval=0, maxval=2 * jnp.pi, shape=(n,))
xs, = grid.axes(offset=offset)
def forcing_fn(t):
@jnp.vectorize
def eval_force(x):
f = amplitude * jnp.sin(omega * t - x * kx + phi)
return f.sum()
return eval_force(xs)
return forcing_fn
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implicit-explicit time stepping routines for ODEs."""
import dataclasses
from typing import Callable, Sequence, TypeVar
import tree_math
PyTreeState = TypeVar("PyTreeState")
TimeStepFn = Callable[[PyTreeState], PyTreeState]
class ImplicitExplicitODE:
"""Describes a set of ODEs with implicit & explicit terms.
The equation is given by:
∂x/∂t = explicit_terms(x) + implicit_terms(x)
`explicit_terms(x)` includes terms that should use explicit time-stepping and
`implicit_terms(x)` includes terms that should be modeled implicitly.
Typically the explicit terms are non-linear and the implicit terms are linear.
This simplifies solves but isn't strictly necessary.
"""
def explicit_terms(self, state: PyTreeState) -> PyTreeState:
"""Evaluates explicit terms in the ODE."""
raise NotImplementedError
def implicit_terms(self, state: PyTreeState) -> PyTreeState:
"""Evaluates implicit terms in the ODE."""
raise NotImplementedError
def implicit_solve(
self, state: PyTreeState, step_size: float,
) -> PyTreeState:
"""Solves `y - step_size * implicit_terms(y) = x` for y."""
raise NotImplementedError
def backward_forward_euler(
equation: ImplicitExplicitODE, time_step: float,
) -> TimeStepFn:
"""Time stepping via forward and backward Euler methods.
This method is first order accurate.
Args:
equation: equation to solve.
time_step: time step.
Returns:
Function that performs a time step.
"""
# pylint: disable=invalid-name
dt = time_step
F = tree_math.unwrap(equation.explicit_terms)
G_inv = tree_math.unwrap(equation.implicit_solve, vector_argnums=0)
@tree_math.wrap
def step_fn(u0):
g = u0 + dt * F(u0)
u1 = G_inv(g, dt)
return u1
return step_fn
def crank_nicolson_rk2(
equation: ImplicitExplicitODE, time_step: float,
) -> TimeStepFn:
"""Time stepping via Crank-Nicolson and 2nd order Runge-Kutta (Heun).
This method is second order accurate.
Args:
equation: equation to solve.
time_step: time step.
Returns:
Function that performs a time step.
Reference:
Chandler, G. J. & Kerswell, R. R. Invariant recurrent solutions embedded in
a turbulent two-dimensional Kolmogorov flow. J. Fluid Mech. 722, 554–595
(2013). https://doi.org/10.1017/jfm.2013.122 (Section 3)
"""
# pylint: disable=invalid-name
dt = time_step
F = tree_math.unwrap(equation.explicit_terms)
G = tree_math.unwrap(equation.implicit_terms)
G_inv = tree_math.unwrap(equation.implicit_solve, vector_argnums=0)
@tree_math.wrap
def step_fn(u0):
g = u0 + 0.5 * dt * G(u0)
h1 = F(u0)
u1 = G_inv(g + dt * h1, 0.5 * dt)
h2 = 0.5 * (F(u1) + h1)
u2 = G_inv(g + dt * h2, 0.5 * dt)
return u2
return step_fn
def low_storage_runge_kutta_crank_nicolson(
alphas: Sequence[float],
betas: Sequence[float],
gammas: Sequence[float],
equation: ImplicitExplicitODE,
time_step: float,
) -> TimeStepFn:
"""Time stepping via "low-storage" Runge-Kutta and Crank-Nicolson steps.
These scheme are second order accurate for the implicit terms, but potentially
higher order accurate for the explicit terms. This seems to be a favorable
tradeoff when the explicit terms dominate, e.g., for modeling turbulent
fluids.
Per Canuto: "[these methods] have been widely used for the time-discretization
in applications of spectral methods."
Args:
alphas: alpha coefficients.
betas: beta coefficients.
gammas: gamma coefficients.
equation: equation to solve.
time_step: time step.
Returns:
Function that performs a time step.
Reference:
Canuto, C., Yousuff Hussaini, M., Quarteroni, A. & Zang, T. A.
Spectral Methods: Evolution to Complex Geometries and Applications to
Fluid Dynamics. (Springer Berlin Heidelberg, 2007).
https://doi.org/10.1007/978-3-540-30728-0 (Appendix D.3)
"""
# pylint: disable=invalid-name,non-ascii-name
α = alphas
β = betas
γ = gammas
dt = time_step
F = tree_math.unwrap(equation.explicit_terms)
G = tree_math.unwrap(equation.implicit_terms)
G_inv = tree_math.unwrap(equation.implicit_solve, vector_argnums=0)
if len(alphas) - 1 != len(betas) != len(gammas):
raise ValueError("number of RK coefficients does not match")
@tree_math.wrap
def step_fn(u):
h = 0
for k in range(len(β)):
h = F(u) + β[k] * h
µ = 0.5 * dt * (α[k + 1] - α[k])
u = G_inv(u + γ[k] * dt * h + µ * G(u), µ)
return u
return step_fn
def crank_nicolson_rk3(
equation: ImplicitExplicitODE, time_step: float,
) -> TimeStepFn:
"""Time stepping via Crank-Nicolson and RK3 ("Williamson")."""
return low_storage_runge_kutta_crank_nicolson(
alphas=[0, 1/3, 3/4, 1],
betas=[0, -5/9, -153/128],
gammas=[1/3, 15/16, 8/15],
equation=equation,
time_step=time_step,
)
def crank_nicolson_rk4(
equation: ImplicitExplicitODE, time_step: float,
) -> TimeStepFn:
"""Time stepping via Crank-Nicolson and RK4 ("Carpenter-Kennedy")."""
# pylint: disable=line-too-long
return low_storage_runge_kutta_crank_nicolson(
alphas=[0, 0.1496590219993, 0.3704009573644, 0.6222557631345, 0.9582821306748, 1],
betas=[0, -0.4178904745, -1.192151694643, -1.697784692471, -1.514183444257],
gammas=[0.1496590219993, 0.3792103129999, 0.8229550293869, 0.6994504559488, 0.1530572479681],
equation=equation,
time_step=time_step,
)
@dataclasses.dataclass
class ImExButcherTableau:
"""Butcher Tableau for implicit-explicit Runge-Kutta methods."""
a_ex: Sequence[Sequence[float]]
a_im: Sequence[Sequence[float]]
b_ex: Sequence[float]
b_im: Sequence[float]
def __post_init__(self):
if len({len(self.a_ex) + 1,
len(self.a_im) + 1,
len(self.b_ex),
len(self.b_im)}) > 1:
raise ValueError("inconsistent Butcher tableau")
def imex_runge_kutta(
tableau: ImExButcherTableau,
equation: ImplicitExplicitODE,
time_step: float,
) -> TimeStepFn:
"""Time stepping with Implicit-Explicit Runge-Kutta."""
# pylint: disable=invalid-name
dt = time_step
F = tree_math.unwrap(equation.explicit_terms)
G = tree_math.unwrap(equation.implicit_terms)
G_inv = tree_math.unwrap(equation.implicit_solve, vector_argnums=0)
a_ex = tableau.a_ex
a_im = tableau.a_im
b_ex = tableau.b_ex
b_im = tableau.b_im
num_steps = len(b_ex)
@tree_math.wrap
def step_fn(y0):
f = [None] * num_steps
g = [None] * num_steps
f[0] = F(y0)
g[0] = G(y0)
for i in range(1, num_steps):
ex_terms = dt * sum(a_ex[i-1][j] * f[j] for j in range(i) if a_ex[i-1][j])
im_terms = dt * sum(a_im[i-1][j] * g[j] for j in range(i) if a_im[i-1][j])
Y_star = y0 + ex_terms + im_terms
Y = G_inv(Y_star, dt * a_im[i-1][i])
if any(a_ex[j][i] for j in range(i, num_steps - 1)) or b_ex[i]:
f[i] = F(Y)
if any(a_im[j][i] for j in range(i, num_steps - 1)) or b_im[i]:
g[i] = G(Y)
ex_terms = dt * sum(b_ex[j] * f[j] for j in range(num_steps) if b_ex[j])
im_terms = dt * sum(b_im[j] * g[j] for j in range(num_steps) if b_im[j])
y_next = y0 + ex_terms + im_terms
return y_next
return step_fn
def imex_rk_sil3(
equation: ImplicitExplicitODE, time_step: float,
) -> TimeStepFn:
"""Time stepping with the SIL3 implicit-explicit RK scheme.
This method is second-order accurate for the implicit terms and third-order
accurate for the explicit terms.
Args:
equation: equation to solve.
time_step: time step.
Returns:
Function that performs a time step.
Reference:
Whitaker, J. S. & Kar, S. K. Implicit-Explicit Runge-Kutta Methods for
Fast-Slow Wave Problems. Monthly Weather Review vol. 141 3426-3434 (2013)
http://dx.doi.org/10.1175/mwr-d-13-00132.1
"""
return imex_runge_kutta(
tableau=ImExButcherTableau(
a_ex=[[1/3], [1/6, 1/2], [1/2, -1/2, 1]],
a_im=[[1/6, 1/6], [1/3, 0, 1/3], [3/8, 0, 3/8, 1/4]],
b_ex=[1/2, -1/2, 1, 0],
b_im=[3/8, 0, 3/8, 1/4],
),
equation=equation,
time_step=time_step,
)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for time_stepping."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import tree_util
from jax import config
import jax.numpy as jnp
from jax_cfd.base import funcutils
from jax_cfd.spectral import time_stepping
import numpy as np
def harmonic_oscillator(x0, t):
theta = jnp.arctan(x0[0] / x0[1])
r = jnp.linalg.norm(x0, ord=2, axis=0)
return r * jnp.stack([jnp.sin(t + theta), jnp.cos(t + theta)])
class CustomODE(time_stepping.ImplicitExplicitODE):
def __init__(self, explicit_terms, implicit_terms, implicit_solve):
self.explicit_terms = explicit_terms
self.implicit_terms = implicit_terms
self.implicit_solve = implicit_solve
ALL_TEST_PROBLEMS = [
# x(t) = np.ones(10)
dict(testcase_name='_zero_derivative',
explicit_terms=lambda x: 0 * x,
implicit_terms=lambda x: 0 * x,
implicit_solve=lambda x, eta: x,
dt=1e-2,
inner_steps=10,
outer_steps=5,
initial_state=np.ones(10),
closed_form=lambda x0, t: x0,
tolerances=[1e-12] * 5),
# x(t) = 5 * t * np.ones(3)
dict(testcase_name='_constant_derivative',
explicit_terms=lambda x: 5 * jnp.ones_like(x),
implicit_terms=lambda x: 0 * x,
implicit_solve=lambda x, eta: x,
dt=1e-2,
inner_steps=10,
outer_steps=5,
initial_state=np.ones(3),
closed_form=lambda x0, t: x0 + 5 * t,
tolerances=[1e-12] * 5),
# x(t) = np.arange(3) * np.exp(t)
# Uses explicit terms only.
dict(testcase_name='_linear_derivative_explicit',
explicit_terms=lambda x: x,
implicit_terms=lambda x: 0 * x,
implicit_solve=lambda x, eta: x,
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.arange(3.0),
closed_form=lambda x0, t: np.arange(3) * jnp.exp(t),
tolerances=[5e-2, 1e-4, 1e-6, 1e-9, 1e-6]),
# x(t) = np.arange(3) * np.exp(t)
# Uses implicit terms only.
dict(testcase_name='_linear_derivative_implicit',
explicit_terms=lambda x: 0 * x,
implicit_terms=lambda x: x,
implicit_solve=lambda x, eta: x / (1 - eta),
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.arange(3.0),
closed_form=lambda x0, t: np.arange(3) * jnp.exp(t),
tolerances=[5e-2, 5e-5, 1e-5, 1e-5, 3e-5]),
# x(t) = np.arange(3) * np.exp(t)
# Splits the equation into an implicit and explicit term.
dict(testcase_name='_linear_derivative_semi_implicit',
explicit_terms=lambda x: x / 2,
implicit_terms=lambda x: x / 2,
implicit_solve=lambda x, eta: x / (1 - eta / 2),
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.arange(3) * np.exp(0),
closed_form=lambda x0, t: np.arange(3.0) * jnp.exp(t),
tolerances=[1e-4, 2e-5, 2e-6, 1e-6, 2e-5]),
dict(testcase_name='_harmonic_oscillator_explicit',
explicit_terms=lambda x: jnp.stack([x[1], -x[0]]),
implicit_terms=jnp.zeros_like,
implicit_solve=lambda x, eta: x,
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.ones(2),
closed_form=harmonic_oscillator,
tolerances=[1e-2, 3e-5, 6e-8, 5e-11, 6e-8]),
dict(testcase_name='_harmonic_oscillator_implicit',
explicit_terms=jnp.zeros_like,
implicit_terms=lambda x: jnp.stack([x[1], -x[0]]),
implicit_solve=lambda x, eta: jnp.stack( # pylint: disable=g-long-lambda
[x[0] + eta * x[1], x[1] - eta * x[0]]) / (1 + eta ** 2),
dt=1e-2,
inner_steps=20,
outer_steps=5,
initial_state=np.ones(2),
closed_form=harmonic_oscillator,
tolerances=[1e-2, 2e-5, 2e-6, 1e-6, 6e-6]),
]
ALL_TIME_STEPPERS = [
time_stepping.backward_forward_euler,
time_stepping.crank_nicolson_rk2,
time_stepping.crank_nicolson_rk3,
time_stepping.crank_nicolson_rk4,
time_stepping.imex_rk_sil3,
]
class TimeSteppingTest(parameterized.TestCase):
@parameterized.named_parameters(ALL_TEST_PROBLEMS)
def test_implicit_solve(
self,
explicit_terms,
implicit_terms,
implicit_solve,
dt,
inner_steps,
outer_steps,
initial_state,
closed_form,
tolerances,
):
"""Tests that time integration is accurate for a range of test cases."""
del dt, explicit_terms, inner_steps, outer_steps, closed_form # unused
del tolerances # unused
# Verifies that `implicit_solve` solves (y - eta * F(y)) = x
# This does not test the integrator, but rather verifies that the test
# case is valid.
eta = 0.3
solved_state = implicit_solve(initial_state, eta)
reconstructed_state = solved_state - eta * implicit_terms(solved_state)
np.testing.assert_allclose(reconstructed_state, initial_state)
@parameterized.named_parameters(ALL_TEST_PROBLEMS)
def test_integration(
self,
explicit_terms,
implicit_terms,
implicit_solve,
dt,
inner_steps,
outer_steps,
initial_state,
closed_form,
tolerances,
):
# Compute closed-form solution.
time = dt * inner_steps * (1 + np.arange(outer_steps))
expected = jax.vmap(closed_form, in_axes=(None, 0))(
initial_state, time)
# Compute trajectory using time-stepper.
for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
with self.subTest(time_stepper.__name__):
equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
semi_implicit_step = time_stepper(equation, dt)
integrator = funcutils.trajectory(
funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
_, actual = integrator(initial_state)
np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
def test_pytree_state(self):
equation = CustomODE(
explicit_terms=lambda x: tree_util.tree_map(jnp.zeros_like, x),
implicit_terms=lambda x: tree_util.tree_map(jnp.zeros_like, x),
implicit_solve=lambda x, eta: x,
)
u0 = {'x': 1.0, 'y': 1.0}
for time_stepper in ALL_TIME_STEPPERS:
with self.subTest(time_stepper.__name__):
u1 = time_stepper(equation, 1.0)(u0)
self.assertEqual(u0, u1)
if __name__ == '__main__':
config.update('jax_enable_x64', True)
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common types that are used throughout the spectral codes."""
from typing import Callable, Union
import jax
import numpy as np
Array = Union[np.ndarray, jax.Array]
StepFn = Callable[[Array], Array]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for building pseudospectral methods."""
from typing import Callable, Tuple
import jax.numpy as jnp
from jax_cfd.base import grids
from jax_cfd.spectral import types as spectral_types
def truncated_rfft(u: spectral_types.Array) -> spectral_types.Array:
"""Applies the 2/3 rule by truncating higher Fourier modes.
Args:
u: the real-space representation of the input signal
Returns:
Downsampled version of `u` in rfft-space.
"""
uhat = jnp.fft.rfft(u)
k, = uhat.shape
final_size = int(2 / 3 * k) + 1
return 2 / 3 * uhat[:final_size]
def padded_irfft(uhat: spectral_types.Array) -> spectral_types.Array:
"""Applies the 3/2 rule by padding with zeros.
Args:
uhat: the rfft representation of a signal
Returns:
An upsampled signal in real space which 3/2 times larger than the input
signal `uhat`.
"""
n, = uhat.shape
final_shape = int(3 / 2 * n)
smoothed = jnp.pad(uhat, (0, final_shape - n))
assert smoothed.shape == (final_shape,), "incorrect padded shape"
return 1.5 * jnp.fft.irfft(smoothed)
def truncated_fft_2x(u: spectral_types.Array) -> spectral_types.Array:
"""Applies the 1/2 rule to complex u by truncating higher Fourier modes.
Args:
u: the (complex) input signal
Returns:
Downsampled version of `u` in fft-space.
"""
uhat = jnp.fft.fftshift(jnp.fft.fft(u))
k, = uhat.shape
final_size = (k + 1) // 2
return jnp.fft.ifftshift(uhat[final_size // 2:(-final_size + 1) // 2]) / 2
def padded_ifft_2x(uhat: spectral_types.Array) -> spectral_types.Array:
"""Applies the 2x rule to complex F[u] by padding higher frequencies.
Pads with zeros in the Fourier domain before performing the ifft
(effectively performing 2x interpolation in the spatial domain)
Args:
uhat: the fft representation of signal
Returns:
An upsampled signal in real space interpolated to 2x more points than
`jax.fft.ifft(uhat)`.
"""
n, = uhat.shape
final_size = n + 2 * (n // 2)
added = n // 2
smoothed = jnp.pad(jnp.fft.fftshift(uhat), (added, added))
assert smoothed.shape == (final_size,), "incorrect padded shape"
return 2 * jnp.fft.ifft(jnp.fft.ifftshift(smoothed))
def circular_filter_2d(grid: grids.Grid) -> spectral_types.Array:
"""Circular filter which roughly matches the 2/3 rule but is smoother.
Follows the technique described in Equation 1 of [1]. We use a different value
for alpha as used by pyqg [2].
Args:
grid: the grid to filter over
Returns:
Filter mask
Reference:
[1] Arbic, Brian K., and Glenn R. Flierl. "Coherent vortices and kinetic
energy ribbons in asymptotic, quasi two-dimensional f-plane turbulence."
Physics of Fluids 15, no. 8 (2003): 2177-2189.
https://doi.org/10.1063/1.1582183
[2] Ryan Abernathey, rochanotes, Malte Jansen, Francis J. Poulin, Navid C.
Constantinou, Dhruv Balwada, Anirban Sinha, Mike Bueti, James Penn,
Christopher L. Pitt Wolfe, & Bia Villas Boas. (2019). pyqg/pyqg: v0.3.0
(v0.3.0). Zenodo. https://doi.org/10.5281/zenodo.3551326.
See:
https://github.com/pyqg/pyqg/blob/02e8e713660d6b2043410f2fef6a186a7cb225a6/pyqg/model.py#L136
"""
kx, ky = grid.rfft_mesh()
max_k = ky[-1, -1]
circle = jnp.sqrt(kx**2 + ky**2)
cphi = 0.65 * max_k
filterfac = 23.6
filter_ = jnp.exp(-filterfac * (circle - cphi)**4.)
filter_ = jnp.where(circle <= cphi, jnp.ones_like(filter_), filter_)
return filter_
def brick_wall_filter_2d(grid: grids.Grid):
"""Implements the 2/3 rule."""
n, m = grid.shape
filter_ = jnp.zeros((n, m // 2 + 1))
filter_ = filter_.at[:int(2 / 3 * n) // 2, :int(2 / 3 * (m // 2 + 1))].set(1)
filter_ = filter_.at[-int(2 / 3 * n) // 2:, :int(2 / 3 * (m // 2 + 1))].set(1)
return filter_
def exponential_filter(signal, alpha=1e-6, order=2):
"""Apply a low-pass smoothing filter to remove noise from 2D signal."""
# Based on:
# 1. Gottlieb and Hesthaven (2001), "Spectral methods for hyperbolic problems"
# https://doi.org/10.1016/S0377-0427(00)00510-0
# 2. Also, see https://arxiv.org/pdf/math/0701337.pdf --- Eq. 5
# TODO(dresdner) save a few ffts by factoring out the actual filter, sigma.
alpha = -jnp.log(alpha)
n, _ = signal.shape # TODO(dresdner) check square / handle 1D case
kx, ky = jnp.fft.fftfreq(n), jnp.fft.rfftfreq(n)
kx, ky = jnp.meshgrid(kx, ky, indexing="ij")
eta = jnp.sqrt(kx**2 + ky**2)
sigma = jnp.exp(-alpha * eta**(2 * order))
return jnp.fft.irfft2(sigma * jnp.fft.rfft2(signal))
def vorticity_to_velocity(
grid: grids.Grid
) -> Callable[[spectral_types.Array], Tuple[spectral_types.Array,
spectral_types.Array]]:
"""Constructs a function for converting vorticity to velocity, both in Fourier domain.
Solves for the stream function and then uses the stream function to compute
the velocity. This is the standard approach. A quick sketch can be found in
[1].
Args:
grid: the grid underlying the vorticity field.
Returns:
A function that takes a vorticity (rfftn) and returns a velocity vector
field.
Reference:
[1] Z. Yin, H.J.H. Clercx, D.C. Montgomery, An easily implemented task-based
parallel scheme for the Fourier pseudospectral solver applied to 2D
Navier–Stokes turbulence, Computers & Fluids, Volume 33, Issue 4, 2004,
Pages 509-520, ISSN 0045-7930,
https://doi.org/10.1016/j.compfluid.2003.06.003.
"""
kx, ky = grid.rfft_mesh()
two_pi_i = 2 * jnp.pi * 1j
laplace = two_pi_i ** 2 * (abs(kx)**2 + abs(ky)**2)
laplace = laplace.at[0, 0].set(1) # pytype: disable=attribute-error # jnp-type
def ret(vorticity_hat):
psi_hat = -1 / laplace * vorticity_hat
vxhat = two_pi_i * ky * psi_hat
vyhat = -two_pi_i * kx * psi_hat
return vxhat, vyhat
return ret
def filter_step(step_fn: spectral_types.StepFn, filter_: spectral_types.Array):
"""Returns a filtered version of the step_fn."""
def new_step_fn(state):
return filter_ * step_fn(state)
return new_step_fn
def spectral_curl_2d(mesh, velocity_hat):
"""Computes the 2D curl in the Fourier basis."""
kx, ky = mesh
uhat, vhat = velocity_hat
return 2j * jnp.pi * (vhat * kx - uhat * ky)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for utils."""
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
from jax_cfd.base import finite_differences
from jax_cfd.base import grids
from jax_cfd.base import initial_conditions
from jax_cfd.base import interpolation
from jax_cfd.base import test_util
from jax_cfd.spectral import utils
class ThreeOverTwoRuleTest1D(test_util.TestCase):
def test_rfft_padding_and_truncation(self):
# This test is essentially recreating Figure 4 of go/uecker
n = 8
grid = grids.Grid((n,), domain=((0, 2 * jnp.pi),))
xs, = grid.axes()
u = jnp.cos(3 * xs)
uhat = jnp.fft.rfft(u)
k, = uhat.shape
uhat_squared = utils.truncated_rfft(utils.padded_irfft(uhat)**2)
assert len(uhat_squared) == k
u_squared = jnp.fft.irfft(uhat_squared)
self.assertAllClose(.5, u_squared, atol=1e-4)
class NavierStokesHelpersTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_seed=0', seed=0),
dict(testcase_name='_seed=1', seed=1))
def test_construct_circular_filter(self, seed):
grid = grids.Grid((8, 8), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
mask = utils.circular_filter_2d(grid)
# check that masking decreasing the l2-norm.
key = jax.random.PRNGKey(seed)
signal = jax.random.normal(key, (8, 8))
signal_hat = jnp.fft.rfftn(signal)
self.assertLess(
jnp.linalg.norm(mask * signal_hat), jnp.linalg.norm(signal_hat))
@parameterized.named_parameters(
dict(testcase_name='_atol=1e-2',
atol=1e-2,
grid=grids.Grid((128, 128),
domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))))
def test_vorticity_to_velocity_round_trip(self, atol, grid):
"""Check that velocity solve and curl 2d are inverses."""
u, v = initial_conditions.filtered_velocity_field(
jax.random.PRNGKey(42), grid, maximum_velocity=7, peak_wavenumber=1)
velocity_solve = utils.vorticity_to_velocity(grid)
vorticity = finite_differences.curl_2d((u, v))
vorticity_hat = jnp.fft.rfftn(vorticity.data)
uhat, vhat = velocity_solve(vorticity_hat)
self.assertAllClose(
jnp.fft.irfftn(uhat),
interpolation.linear(u, vorticity.offset).data,
atol=atol)
self.assertAllClose(
jnp.fft.irfftn(vhat),
finite_differences.interpolation.linear(v, vorticity.offset).data,
atol=atol)
if __name__ == '__main__':
absltest.main()
# 模型唯一标识
modelCode = 647
# 模型名称
modelName=jax-cfd
# 模型描述
modelDescription=机器学习、自动微分和硬件加速器在计算流体动力学中潜在应用的实验研究项目。
# 应用场景
appScenario=推理,流体动力学,化工,气象,能源
# 框架类型
frameType=jax
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment