Commit 5e31fa1f authored by mashun1
jax-cfd
"""Models for advection and convection components."""
import functools
from typing import Callable, Optional
import gin
from jax_cfd.base import advection
from jax_cfd.base import grids
from jax_cfd.ml import interpolations
from jax_cfd.ml import physics_specifications
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationModule = interpolations.InterpolationModule
AdvectFn = Callable[[GridVariable, GridVariableVector, float], GridArray]
AdvectionModule = Callable[..., AdvectFn]
ConvectFn = Callable[[GridVariableVector], GridArrayVector]
ConvectionModule = Callable[..., ConvectFn]
@gin.register
def modular_advection(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
c_interpolation_module: InterpolationModule = interpolations.upwind,
u_interpolation_module: InterpolationModule = interpolations.linear,
**kwargs
) -> AdvectFn:
"""Modular advection module based on `advection_diffusion.advect_general`."""
c_interpolate_fn = c_interpolation_module(grid, dt, physics_specs, **kwargs)
u_interpolate_fn = u_interpolation_module(grid, dt, physics_specs, **kwargs)
def advect(
c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None
) -> GridArray:
return advection.advect_general(
c, v, u_interpolate_fn, c_interpolate_fn, dt)
return advect
@gin.register
def modular_self_advection(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
interpolation_module: InterpolationModule,
**kwargs
) -> AdvectFn:
"""Modular self advection using a single interpolation module."""
# TODO(jamieas): Replace this entire function once
# `single_tower_navier_stokes` is in place.
interpolate_fn = interpolation_module(grid, dt, physics_specs, **kwargs)
c_interpolate_fn = functools.partial(interpolate_fn, tag='c')
u_interpolate_fn = functools.partial(interpolate_fn, tag='u')
def advect(
c: GridVariable,
v: GridVariableVector,
dt: Optional[float] = None
) -> GridArray:
return advection.advect_general(
c, v, u_interpolate_fn, c_interpolate_fn, dt)
return advect
@gin.register
def self_advection(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
advection_module: AdvectionModule = modular_advection,
**kwargs
) -> ConvectFn:
"""Convection module based on simultaneous self-advection of velocities."""
advect_fn = advection_module(grid, dt, physics_specs, **kwargs)
def convect(v: GridVariableVector) -> GridArrayVector:
return tuple(advect_fn(u, v, dt) for u in v)
return convect
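# Example (illustrative sketch, not part of the original module): how the
# modules above compose. `grid`, `dt` and `physics_specs` are assumed to be
# provided by the surrounding simulation or training setup.
def _example_build_convect_fn(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
) -> ConvectFn:
  # `self_advection` instantiates `modular_advection` once (with the default
  # upwind scalar and linear velocity interpolation) and advects every
  # velocity component by the full velocity field.
  convect_fn = self_advection(
      grid, dt, physics_specs, advection_module=modular_advection)
  return convect_fn  # convect_fn(v) returns a GridArrayVector of tendencies.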
"""Decoder modules that help interfacing model states with output data.
All decoder modules generate a function that given an specific model state
return the observable data of the same structure as provided to the Encoder.
Decoders can be either fixed functions, decorators, or learned modules.
"""
from typing import Any, Callable, Optional
import gin
import haiku as hk
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers
from jax_cfd.spectral import utils as spectral_utils
DecodeFn = Callable[[Any], Any] # maps model state to data time slice.
DecoderModule = Callable[..., DecodeFn] # generate DecodeFn closed over args.
TowerFactory = towers.TowerFactory
@gin.register
def identity_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
"""Identity decoder module that returns model state as is."""
del grid, dt, physics_specs # unused.
def decode_fn(inputs):
return inputs
return decode_fn
# TODO(dkochkov) generalize this to arbitrary pytrees.
@gin.register
def aligned_array_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
"""Generates decoder that extracts data from GridVariables."""
del grid, dt, physics_specs # unused.
def decode_fn(inputs):
return tuple(x.data for x in inputs)
return decode_fn
@gin.register
def staggered_to_collocated_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
):
"""Decoder that interpolates from staggered to collocated grids."""
del dt, physics_specs # unused.
def decode_fn(inputs):
interp_inputs = [interpolation.linear(c, grid.cell_center) for c in inputs]
return tuple(x.data for x in interp_inputs)
return decode_fn
@gin.register
def channels_split_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
"""Generates decoder that splits channels into data tuples."""
del grid, dt, physics_specs # unused.
def decode_fn(inputs):
return array_utils.split_axis(inputs, -1)
return decode_fn
@gin.register
def latent_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
tower_factory: TowerFactory,
num_components: Optional[int] = None,
):
"""Generates trainable decoder that maps latent representation to data tuple.
  The decoder first computes an array of outputs using a network specified by
  `tower_factory` and then splits the channels into `num_components` components.
Args:
    grid: grid representing spatial discretization of the system.
dt: time step to use for time evolution.
physics_specs: physical parameters of the simulation.
tower_factory: factory that produces trainable tower network module.
num_components: number of data tuples in the data representation of the
      state. If None, assumes num_components == grid.ndim. Default is None.
Returns:
decode function that maps latent state `inputs` at given time to a tuple of
`num_components` data arrays representing the same state at the same time.
"""
split_channels_fn = channels_split_decoder(grid, dt, physics_specs)
def decode_fn(inputs):
num_channels = num_components or grid.ndim
decoder_tower = tower_factory(num_channels, grid.ndim, name='decoder')
return split_channels_fn(decoder_tower(inputs))
return hk.to_module(decode_fn)()
@gin.register
def aligned_latent_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
tower_factory: TowerFactory,
num_components: Optional[int] = None,
):
"""Latent decoder that decodes from aligned arrays."""
split_channels_fn = channels_split_decoder(grid, dt, physics_specs)
def decode_fn(inputs):
inputs = jnp.stack([x.data for x in inputs], axis=-1)
num_channels = num_components or grid.ndim
decoder_tower = tower_factory(num_channels, grid.ndim, name='decoder')
return split_channels_fn(decoder_tower(inputs))
return hk.to_module(decode_fn)()
@gin.register
def vorticity_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
"""Solves for velocity and converts into GridVariables."""
del dt, physics_specs # unused.
velocity_solve = spectral_utils.vorticity_to_velocity(grid)
def decode_fn(vorticity):
    # TODO(dresdner) the main difference from `spectral_vorticity_decoder` is
    # that the input is in real space rather than spectral space.
vorticity = jnp.squeeze(vorticity, axis=-1) # remove channel dim
vorticity_hat = jnp.fft.rfft2(vorticity)
uhat, vhat = velocity_solve(vorticity_hat)
v = (jnp.fft.irfft2(uhat), jnp.fft.irfft2(vhat))
return v
return decode_fn
@gin.register
def spectral_vorticity_decoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
"""Solves for velocity and converts into GridVariables."""
del dt, physics_specs # unused.
velocity_solve = spectral_utils.vorticity_to_velocity(grid)
def decode_fn(vorticity_hat):
uhat, vhat = velocity_solve(vorticity_hat)
v = (jnp.fft.irfft2(uhat), jnp.fft.irfft2(vhat))
return v
return decode_fn
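# Example (illustrative sketch, not part of the original module): decoding a
# spectral vorticity state back to real-space velocity components. The shape of
# `vorticity_hat` (an rfft2 of a `grid.shape` array) is an assumption made for
# illustration.
def _example_decode_spectral_vorticity(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    vorticity_hat,
):
  decode_fn = spectral_vorticity_decoder(grid, dt, physics_specs)
  u, v = decode_fn(vorticity_hat)  # real-space velocity components.
  return u, v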
"""Models for diffusion components.
All modules are functions that return `DiffuseFn` or `DiffusionSolveFn` method.
The two types of diffusion modules should be used with corresponding explicit
and implicit navier-stokes solvers.
An example explicit diffusion module:
```python
def diffusion_module(dt, module_params, **kwargs):
pre_compute_values = f(dt, module_params)
def diffuse(c: grids.GridVariable, nu: float, grid: grids.Grid, dt: float):
# compute time derivative due to diffusion.
return dc_dt
return diffuse
```
"""
import functools
from typing import Callable, Optional
import gin
import haiku as hk
from jax_cfd.base import diffusion
from jax_cfd.base import grids
from jax_cfd.base import subgrid_models
from jax_cfd.ml import viscosities
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
DiffuseFn = Callable[[GridVariable, float], GridArray]
DiffusionSolveFn = Callable[[GridVariableVector, float, float],
GridVariableVector]
DiffuseModule = Callable[..., DiffuseFn]
DiffusionSolveModule = Callable[..., DiffusionSolveFn]
ViscosityModule = viscosities.ViscosityModule
# TODO(shoyer): stop deleting unrecognized **kwargs. This is really error-prone!
@gin.register(denylist=("grid", "dt", "physics_specs"))
def diffuse(grid, dt, physics_specs) -> DiffuseFn:
del grid, dt, physics_specs # unused.
return diffusion.diffuse
@gin.register(denylist=("grid", "dt", "physics_specs"))
def solve_fast_diag(
grid,
dt,
physics_specs,
implementation=None
) -> DiffusionSolveFn:
del grid, dt, physics_specs # unused.
return functools.partial(
diffusion.solve_fast_diag, implementation=implementation)
@gin.register(denylist=("grid", "dt", "physics_specs"))
def solve_cg(
grid,
dt,
physics_specs,
atol: float = 1e-5,
rtol: float = 1e-5,
maxiter: Optional[int] = 64,
) -> DiffusionSolveFn:
"""Returns conjugate gradient solve method."""
del grid, dt, physics_specs # unused.
return functools.partial(
diffusion.solve_cg, atol=atol, rtol=rtol, maxiter=maxiter)
@gin.register(denylist=("grid", "dt", "physics_specs"))
def implicit_evm_solve_with_diffusion(
grid,
dt,
physics_specs,
viscosity_module: ViscosityModule = viscosities.eddy_viscosity_model,
atol: float = 1e-5,
maxiter: Optional[int] = 64,
) -> DiffusionSolveFn:
"""Returns solve_diffusion method that also includes a viscosity model."""
evm_model = viscosity_module(grid, dt, physics_specs)
cg_kwargs = dict(atol=atol, maxiter=maxiter)
diffusion_solve = functools.partial(
subgrid_models.implicit_evm_solve_with_diffusion,
configured_evm_model=evm_model,
cg_kwargs=cg_kwargs)
return hk.to_module(diffusion_solve)(name="diffusion_solve")
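# Example (illustrative sketch, not part of the original module): using the
# explicit `diffuse` module. The returned DiffuseFn maps a GridVariable `c` and
# viscosity `nu` to the diffusive contribution to dc/dt as a GridArray.
def _example_explicit_diffusion(
    grid: grids.Grid,
    dt: float,
    physics_specs,
    c: GridVariable,
    nu: float,
) -> GridArray:
  diffuse_fn = diffuse(grid, dt, physics_specs)  # simply `diffusion.diffuse`.
  return diffuse_fn(c, nu)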
"""Encoder modules that help interfacing input trajectories to model states.
All encoder modules generate a function that, given an input trajectory, infers
the final state of the physical system in the representation defined by the
Encoder. Encoders can be either fixed functions, decorators or learned modules.
The input state is expected to consist of arrays with `time` as a leading axis.
"""
from typing import Any, Callable, Optional, Tuple
import gin
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers
EncodeFn = Callable[[Any], Any] # maps input trajectory to final model state.
EncoderModule = Callable[..., EncodeFn] # generate EncodeFn closed over args.
TowerFactory = towers.TowerFactory
@gin.register
def aligned_array_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
"""Generates encoder that wraps last data slice as GridVariables."""
del dt # unused.
if hasattr(physics_specs, 'combo_offsets'):
data_offsets = physics_specs.combo_offsets()
else:
data_offsets = data_offsets or grid.cell_faces
slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
def encode_fn(inputs):
if hasattr(physics_specs, 'combo_boundaries'):
bcs = physics_specs.combo_boundaries()
else:
bcs = tuple(
boundaries.periodic_boundary_conditions(grid.ndim)
for _ in range(len(inputs)))
return tuple(
bc.impose_bc(grids.GridArray(slice_last_fn(x), offset, grid))
for x, offset, bc in zip(inputs, data_offsets, bcs))
return encode_fn
@gin.register
def collocated_to_staggered_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
) -> EncodeFn:
"""Encoder that interpolates from collocated to staggered grids."""
del dt, physics_specs # unused.
slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
def encode_fn(inputs):
bc = boundaries.periodic_boundary_conditions(grid.ndim)
src_offset = grid.cell_center
pre_interp = tuple(
grids.GridVariable(
grids.GridArray(slice_last_fn(x), src_offset, grid), bc)
for x in inputs)
return tuple(interpolation.linear(c, offset)
for c, offset in zip(pre_interp, grid.cell_faces))
return encode_fn
@gin.register
def slice_last_state_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
time_axis=0,
) -> EncodeFn:
"""Generates encoder that returns last data slice along time axis."""
del grid, dt, physics_specs # unused.
def encode_fn(inputs):
return array_utils.slice_along_axis(inputs, time_axis, -1)
return encode_fn
@gin.register
def slice_last_n_state_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
n: int = gin.REQUIRED,
time_axis: int = 0,
) -> EncodeFn:
"""Generates encoder that returns last `n` data slices along last axis."""
del grid, dt, physics_specs # unused.
def encode_fn(inputs):
init_slice = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
return jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), init_slice)
return encode_fn
@gin.register
def stack_last_n_state_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
n: int = gin.REQUIRED,
time_axis: int = 0,
) -> EncodeFn:
"""Generates encoder that stacks last `n` inputs slices along last axis."""
del grid, dt, physics_specs # unused.
def encode_fn(inputs):
inputs = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
inputs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), inputs)
    return array_utils.concat_along_axis(
        jax.tree_util.tree_leaves(inputs), axis=-1)
return encode_fn
@gin.register
def latent_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
tower_factory: TowerFactory,
num_latent_dims: int,
n_frames: int,
time_axis: int = 0,
):
"""Generates trainable encoder that maps inputs to a latent representation.
Encoder first stacks last `n_frames` time slices in input trajectory along
channels and then applies a network specified by a `tower_factory` to obtain
a latent field representation with `num_latent_dims` channel dimensions.
Args:
    grid: grid representing spatial discretization of the system.
dt: time step to use for time evolution.
physics_specs: physical parameters of the simulation.
tower_factory: factory that produces trainable tower network module.
num_latent_dims: number of channels to have in latent representation.
n_frames: number of last frames in input trajectory to use for encoding.
    time_axis: axis in input trajectory that corresponds to time.
Returns:
encode function that maps input trajectory `inputs` to a latent field
    representation with `num_latent_dims` channels. Note that, depending on the
    tower used, the spatial dimensions of the representation might differ from
    those of `inputs`.
"""
stack_inputs_fn = stack_last_n_state_encoder(
grid, dt, physics_specs, n_frames, time_axis)
def encode_fn(inputs):
inputs = stack_inputs_fn(inputs)
encoder_tower = tower_factory(num_latent_dims, grid.ndim, name='encoder')
return encoder_tower(inputs)
return hk.to_module(encode_fn)()
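# Example (illustrative sketch, not part of the original module): latent_encoder
# builds Haiku modules, so it must be constructed and called inside an
# `hk.transform`. The tower factory and hyper-parameters below are assumptions
# made for illustration only.
def _example_latent_encoder_transform(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
):
  def forward(trajectory):
    encode_fn = latent_encoder(
        grid, dt, physics_specs,
        tower_factory=towers.forward_tower_factory,
        num_latent_dims=8,
        n_frames=4)
    return encode_fn(trajectory)
  # `init(rng, trajectory)` then `apply(params, trajectory)`, as usual in Haiku.
  return hk.without_apply_rng(hk.transform(forward))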
@gin.register
def aligned_latent_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
tower_factory: TowerFactory,
num_latent_dims: int,
n_frames: int,
time_axis: int = 0,
data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
):
"""Latent encoder that decodes to GridVariables."""
data_offsets = data_offsets or grid.cell_faces
stack_inputs_fn = stack_last_n_state_encoder(
grid, dt, physics_specs, n_frames, time_axis)
def encode_fn(inputs):
bc = boundaries.periodic_boundary_conditions(grid.ndim)
inputs = stack_inputs_fn(inputs)
encoder_tower = tower_factory(num_latent_dims, grid.ndim, name='encoder')
raw_outputs = encoder_tower(inputs)
split_outputs = [raw_outputs[..., i] for i in range(raw_outputs.shape[-1])]
return tuple(
grids.GridVariable(grids.GridArray(x, offset, grid), bc)
for x, offset in zip(split_outputs, data_offsets))
return hk.to_module(encode_fn)()
@gin.register
def vorticity_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
"""Maps velocity to vorticity."""
del dt, physics_specs, data_offsets # unused.
slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
def encode_fn(inputs):
u, v = inputs
u, v = slice_last_fn(u), slice_last_fn(v)
uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
kx, ky = grid.rfft_mesh()
vorticity_hat = 2j * jnp.pi * (vhat * kx - uhat * ky)
# TODO(dresdner) main difference is that the output is ifft'ed.
# TODO(dresdner) and also that the output has a channel dim.
return jnp.fft.irfft2(vorticity_hat)[..., jnp.newaxis]
return encode_fn
@gin.register
def vorticity_velocity_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
"""Maps velocity to [velocity; vorticity]."""
del dt, physics_specs, data_offsets # unused.
slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
ifft = jnp.fft.irfft2
def encode_fn(inputs):
u, v = inputs
u, v = slice_last_fn(u), slice_last_fn(v)
uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
kx, ky = grid.rfft_mesh()
vorticity_hat = 2j * jnp.pi * (vhat * kx - uhat * ky)
return jnp.stack([u, v, ifft(vorticity_hat)], axis=-1)
return encode_fn
@gin.register
def spectral_vorticity_encoder(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
"""Generates encoder that wraps last data slice as GridVariables."""
del dt, physics_specs, data_offsets # unused.
slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
def encode_fn(inputs):
u, v = inputs
u, v = slice_last_fn(u), slice_last_fn(v)
uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
kx, ky = grid.rfft_mesh()
vorticity = 2j * jnp.pi * (vhat * kx - uhat * ky)
return vorticity
return encode_fn
"""Implementations of equation modules."""
from typing import Any, Callable, Tuple
import gin
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd import spectral
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import equations
from jax_cfd.base import grids
from jax_cfd.ml import advections
from jax_cfd.ml import diffusions
from jax_cfd.ml import forcings
from jax_cfd.ml import networks # pylint: disable=unused-import
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import pressures
from jax_cfd.ml import time_integrators
from jax_cfd.spectral import utils as spectral_utils
ConvectionModule = advections.ConvectionModule
DiffuseModule = diffusions.DiffuseModule
DiffusionSolveModule = diffusions.DiffusionSolveModule
ForcingModule = forcings.ForcingModule
PressureModule = pressures.PressureModule
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# TODO(dkochkov) move diffusion to modular_navier_stokes after b/160947162.
@gin.register(denylist=("grid", "dt", "physics_specs"))
def semi_implicit_navier_stokes(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
diffusion_module: DiffuseModule = diffusions.diffuse,
**kwargs,
):
"""Semi-implicit navier stokes solver compatible with explicit diffusion."""
diffusion = diffusion_module(grid, dt, physics_specs)
step_fn = equations.semi_implicit_navier_stokes(
diffuse=diffusion, grid=grid, dt=dt, **kwargs)
return hk.to_module(step_fn)()
@gin.register(denylist=("grid", "dt", "physics_specs"))
def implicit_diffusion_navier_stokes(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
diffusion_module: DiffusionSolveModule = diffusions.solve_fast_diag,
**kwargs
):
"""Implicit navier stokes solver compatible with implicit diffusion."""
diffusion = diffusion_module(grid, dt, physics_specs)
step_fn = equations.implicit_diffusion_navier_stokes(
diffusion_solve=diffusion, grid=grid, dt=dt, **kwargs)
return hk.to_module(step_fn)()
@gin.register(denylist=("grid", "dt", "physics_specs"))
def modular_spectral_step_fn(
grid,
dt,
physics_specs,
do_filter_step=False,
time_stepper=spectral.time_stepping.crank_nicolson_rk4,
):
"""Returns a spectral solver for Forced Navier-Stokes flows."""
eq = spectral.equations.NavierStokes2D(
physics_specs.viscosity,
grid,
drag=physics_specs.drag,
forcing_fn=physics_specs.forcing_module,
smooth=physics_specs.smooth)
step_fn = time_stepper(eq, dt)
if do_filter_step:
    # lambdas don't play nicely with gin config.
def ret(vhat):
v = jnp.fft.irfft2(step_fn(vhat)) # TODO(dresdner) unnecessary fft's
return jnp.fft.rfft2(spectral_utils.exponential_filter(v))
else:
ret = step_fn
return hk.to_module(ret)()
@gin.configurable(denylist=("grid", "dt", "physics_specs"))
def modular_navier_stokes_model(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
equation_solver=implicit_diffusion_navier_stokes,
convection_module: ConvectionModule = advections.self_advection,
pressure_module: PressureModule = pressures.fast_diagonalization,
acceleration_modules=(),
):
"""Returns an incompressible Navier-Stokes time step model.
  This model is derived from standard components of numerical solvers that could
  be replaced with learned components. Note that the diffusion module is
  specified in `equation_solver` due to differences between implicit and
  explicit schemes.
Args:
grid: grid on which the Navier-Stokes equation is discretized.
dt: time step to use for time evolution.
physics_specs: physical parameters of the simulation module.
equation_solver: solver to call to create a time-stepping function.
convection_module: module to use to simulate convection.
pressure_module: module to use to perform pressure projection.
    acceleration_modules: additional explicit terms to be added to the equation
      before the pressure projection step.
Returns:
    A function that performs a single time step of the Navier-Stokes dynamics.
"""
active_forcing_fn = physics_specs.forcing_module(grid)
def navier_stokes_step_fn(state):
"""Advances Navier-Stokes state forward in time."""
v = state
for u in v:
if not isinstance(u, grids.GridVariable):
raise ValueError(f"Expected GridVariable type, got {type(u)}")
convection = convection_module(grid, dt, physics_specs, v=v)
accelerations = [
acceleration_module(grid, dt, physics_specs, v=v)
for acceleration_module in acceleration_modules
]
forcing = forcings.sum_forcings(active_forcing_fn, *accelerations)
pressure_solve_fn = pressure_module(grid, dt, physics_specs)
step_fn = equation_solver(
grid=grid,
dt=dt,
physics_specs=physics_specs,
density=physics_specs.density,
viscosity=physics_specs.viscosity,
pressure_solve=pressure_solve_fn,
convect=convection,
forcing=forcing)
return step_fn(v)
return hk.to_module(navier_stokes_step_fn)()
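# Example (illustrative sketch, not part of the original module): the model
# factory above returns a Haiku module, so a step function is usually obtained
# via `hk.transform`, as in the tests further below. Gin configuration of the
# sub-modules (convection, pressure, forcing, ...) is assumed to have been
# parsed beforehand.
def _example_transform_navier_stokes_model(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
):
  def step_fwd(v):
    model = modular_navier_stokes_model(grid, dt, physics_specs)
    return model(v)
  # step.init(rng, v) builds parameters; step.apply(params, v) advances by dt.
  return hk.without_apply_rng(hk.transform(step_fwd))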
@gin.register
def time_derivative_network_model(
grid: grids.Grid,
dt: float,
physics_specs: Any,
derivative_modules: Tuple[Callable, ...],
time_integrator=time_integrators.euler_integrator,
):
"""Returns a ML model that performs time stepping by time integration.
Note: the model state is assumed to be a stack of observable values
along the last axis.
Args:
grid: grid specifying spatial discretization of the physical system.
dt: time step to use for time evolution.
physics_specs: physical parameters of the simulation module.
derivative_modules: tuple of modules that are used sequentially to compute
unforced time derivative of the input state, which is then integrated.
time_integrator: time integration scheme to use.
Returns:
`step_fn` that advances the input state forward in time by `dt`.
"""
active_forcing_fn = physics_specs.forcing_module(grid)
def step_fn(state):
"""Advances `state` forward in time by `dt`."""
modules = [module(grid, dt, physics_specs) for module in derivative_modules]
def time_derivative_fn(x):
v = array_utils.split_axis(x, axis=-1) # Tuple[DeviceArray, ...]
v = tuple(grids.GridArray(u, o, grid) for u, o in zip(v, grid.cell_faces))
# TODO(pnorgaard) Explicitly specify boundary conditions for ML model
bc = boundaries.periodic_boundary_conditions(grid.ndim)
v = tuple(grids.GridVariable(u, bc) for u in v)
forcing_scalars = jnp.stack(
[a.data for a in active_forcing_fn(v)], axis=-1)
# TODO(dkochkov) consider conditioning on the forcing terms.
for module_fn in modules:
x = module_fn(x)
return x + forcing_scalars
time_derivative_module = hk.to_module(time_derivative_fn)()
out, _ = time_integrator(time_derivative_module, state, dt, 1)
return out
return hk.to_module(step_fn)()
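# Example (illustrative sketch, not part of the original module): with the
# default `euler_integrator`, one step of the model above is conceptually a
# forward-Euler update of the stacked observables.
def _example_forward_euler_step(state, time_derivative_fn, dt):
  # `state` and `time_derivative_fn(state)` are arrays of the same shape, with
  # observable components stacked along the last axis.
  return state + dt * time_derivative_fn(state)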
@gin.register
def identity_model(grid, dt, physics_specs):
"""A model that just returns the original state."""
del grid, dt, physics_specs
def step_fn(state):
return state
return step_fn
@gin.register
def learned_corrector(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
base_solver_module: Callable,
corrector_module: Callable,
):
"""Returns a model that uses base solver with ML correction step."""
# Idea similar to solver in the loop in https://arxiv.org/abs/2007.00016 and
# learned corrector in https://arxiv.org/pdf/2102.01010.pdf.
base_solver = base_solver_module(grid, dt, physics_specs)
corrector = corrector_module(grid, dt, physics_specs)
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(next_state)
return jax.tree_util.tree_map(lambda x, y: x + y, next_state, corrections)
return hk.to_module(step_fn)()
@gin.register
def learned_corrector_v2(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
base_solver_module: Callable,
corrector_module: Callable,
):
"""Like learned_corrector, but based on the input rather than output state."""
base_solver = base_solver_module(grid, dt, physics_specs)
corrector = corrector_module(grid, dt, physics_specs)
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(state)
return jax.tree_util.tree_map(lambda x, y: x + dt * y, next_state, corrections)
return hk.to_module(step_fn)()
@gin.register
def learned_corrector_v3(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
base_solver_module: Callable,
corrector_module: Callable,
):
"""Like learned_corrector, but based on input & output states."""
base_solver = base_solver_module(grid, dt, physics_specs)
corrector = corrector_module(grid, dt, physics_specs)
def step_fn(state):
next_state = base_solver(state)
corrections = corrector(tuple(state) + tuple(next_state))
return jax.tree_util.tree_map(lambda x, y: x + dt * y, next_state, corrections)
return hk.to_module(step_fn)()
"""Tests for models_v2.equations."""
import copy
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import gin
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.ml import equations
from jax_cfd.ml import physics_specifications
GRIDS = [
grids.Grid((32, 32), domain=((0, 2 * jnp.pi),) * 2),
grids.Grid((8, 8, 8), domain=((0, 2 * jnp.pi),) * 3),
]
C_INTERPOLATION_MODULES = [
'interpolations.upwind',
'interpolations.linear',
'interpolations.lax_wendroff',
'FusedLearnedInterpolation',
'IndividualLearnedInterpolation',
]
PRESSURE_MODULES = [
'pressures.fast_diagonalization',
'pressures.conjugate_gradient',
]
FORCING_MODULES = [
'forcings.filtered_linear_forcing',
'forcings.kolmogorov_forcing',
'forcings.taylor_green_forcing',
]
FORCING_SCALE = .1
def navier_stokes_test_parameters():
product = itertools.product(GRIDS,
C_INTERPOLATION_MODULES,
PRESSURE_MODULES,
FORCING_MODULES)
parameters = []
for grid, interpolation, pressure, forcing in product:
name = '_'.join([module.split('.')[-1]
for module in (interpolation, pressure, forcing)])
shape = 'x'.join(str(s) for s in grid.shape)
name = f'{name}_{shape}'
parameters.append(dict(
testcase_name=name,
c_interpolation_module=interpolation,
pressure_module=pressure,
grid=grid,
forcing_module=forcing,
convection_module='advections.self_advection',
u_interpolation_module='interpolations.linear'))
return parameterized.named_parameters(*parameters)
def ml_test_parameters():
product = itertools.product(GRIDS, FORCING_MODULES)
parameters = []
for grid, forcing in product:
shape = 'x'.join(str(s) for s in grid.shape)
name = f'epd_{forcing.split(".")[-1]}_{shape}'
parameters.append(
dict(testcase_name=name, grid=grid, forcing_module=forcing))
return parameterized.named_parameters(*parameters)
class NavierStokesModulesTest(test_util.TestCase):
"""Integration tests for equations and its submodules."""
def _generate_inputs_and_outputs(self, config, grid):
gin.enter_interactive_mode()
gin.parse_config(config)
dt = 0.1
physics_specs = physics_specifications.get_physics_specs()
def step_fwd(x):
model = equations.modular_navier_stokes_model(
grid, dt, physics_specs)
return model(x)
step_model = hk.without_apply_rng(hk.transform(step_fwd))
inputs = []
for seed, offset in enumerate(grid.cell_faces):
rng_key = jax.random.PRNGKey(seed)
data = jax.random.uniform(rng_key, grid.shape, jnp.float32)
variable = grids.GridVariable(
array=grids.GridArray(data, offset, grid),
bc=boundaries.periodic_boundary_conditions(grid.ndim))
inputs.append(variable)
inputs = tuple(inputs)
rng = jax.random.PRNGKey(42)
with funcutils.init_context():
params = step_model.init(rng, inputs)
self.assertIsNotNone(params)
outputs = step_model.apply(params, inputs)
return inputs, outputs
@navier_stokes_test_parameters()
def test_all_modules(
self,
convection_module,
c_interpolation_module,
u_interpolation_module,
pressure_module,
forcing_module,
grid
):
"""Intgeration tests checking that `step_fn` produces expected shape."""
interp_module = 'advections.modular_advection'
ns_module_name = 'equations.modular_navier_stokes_model'
config = [
f'{interp_module}.c_interpolation_module = @{c_interpolation_module}',
f'{interp_module}.u_interpolation_module = @{u_interpolation_module}',
f'{ns_module_name}.convection_module = @{convection_module}',
f'{ns_module_name}.pressure_module = @{pressure_module}',
f'{forcing_module}.scale = {FORCING_SCALE}',
f'NavierStokesPhysicsSpecs.forcing_module = @{forcing_module}',
'NavierStokesPhysicsSpecs.density = 1.',
'NavierStokesPhysicsSpecs.viscosity = 0.1',
'get_physics_specs.physics_specs_cls = @NavierStokesPhysicsSpecs',
]
inputs, outputs = self._generate_inputs_and_outputs(config, grid)
for u_output, u_input in zip(outputs, inputs):
self.assertEqual(u_output.shape, u_input.shape)
def test_smagorinsky(self):
"""Tests that eddy viscosity models predict expected shapes."""
diffusion_solver = 'implicit_diffusion_navier_stokes'
evm_module_name = 'implicit_evm_solve_with_diffusion'
config = [
f'{diffusion_solver}.diffusion_module = @{evm_module_name}',
f'{evm_module_name}.viscosity_module = @eddy_viscosity_model',
'eddy_viscosity_model.viscosity_model = @smagorinsky_viscosity',
'smagorinsky_viscosity.cs = 0.2',
'NavierStokesPhysicsSpecs.forcing_module = @kolmogorov_forcing',
'NavierStokesPhysicsSpecs.density = 1.',
'NavierStokesPhysicsSpecs.viscosity = 0.1',
'get_physics_specs.physics_specs_cls = @NavierStokesPhysicsSpecs',
]
grid = GRIDS[0]
inputs, outputs = self._generate_inputs_and_outputs(config, grid)
for u_output, u_input in zip(outputs, inputs):
self.assertEqual(u_output.shape, u_input.shape)
def test_learned_viscosity_modules(self,):
"""Intgeration tests checking that `step_fn` produces expected shape."""
ns_module_name = 'equations.modular_navier_stokes_model'
model_gin_config = '\n'.join([
f'{ns_module_name}.pressure_module = @fast_diagonalization',
f'{ns_module_name}.convection_module = @self_advection',
f'{ns_module_name}.acceleration_modules = (@eddy_viscosity_model,)',
'eddy_viscosity_model.viscosity_model = @learned_scalar_viscosity',
'learned_scalar_viscosity.tower_factory = @MlpTowerFactory',
'MlpTowerFactory.num_hidden_units = 16',
'MlpTowerFactory.num_hidden_layers = 3',
f'{ns_module_name}.equation_solver = @semi_implicit_navier_stokes',
'semi_implicit_navier_stokes.diffusion_module = @diffuse',
'self_advection.advection_module = @modular_advection',
'modular_advection.u_interpolation_module = @linear',
'modular_advection.c_interpolation_module = @transformed',
'transformed.base_interpolation_module = @lax_wendroff',
'transformed.transformation = @tvd_limiter_transformation',
'NavierStokesPhysicsSpecs.forcing_module = @kolmogorov_forcing',
'NavierStokesPhysicsSpecs.density = 1.',
'NavierStokesPhysicsSpecs.viscosity = 0.1',
'get_physics_specs.physics_specs_cls = @NavierStokesPhysicsSpecs',
])
grid = GRIDS[0]
inputs, outputs = self._generate_inputs_and_outputs(model_gin_config, grid)
for u_input, u_output in zip(inputs, outputs):
self.assertEqual(u_input.shape, u_output.shape)
def test_alternate_implementation_consistency(self):
convection_module = 'advections.self_advection'
advection_module = 'advections.modular_self_advection'
interpolation_module = 'FusedLearnedInterpolation'
pressure_module = 'pressures.fast_diagonalization'
forcing_module = 'forcings.kolmogorov_forcing'
ns_module_name = 'equations.modular_navier_stokes_model'
grid = grids.Grid((32, 32), domain=((0, 2 * jnp.pi),) * 2)
config = [
f'{advection_module}.interpolation_module = @{interpolation_module}',
f'{convection_module}.advection_module = @{advection_module}',
f'{ns_module_name}.convection_module = @{convection_module}',
f'{ns_module_name}.pressure_module = @{pressure_module}',
f'{forcing_module}.scale = {FORCING_SCALE}',
'FusedLearnedInterpolation.tags = ("u", "c")',
f'NavierStokesPhysicsSpecs.forcing_module = @{forcing_module}',
'NavierStokesPhysicsSpecs.density = 1.',
'NavierStokesPhysicsSpecs.viscosity = 0.1',
'get_physics_specs.physics_specs_cls = @NavierStokesPhysicsSpecs',
]
_, outputs1 = self._generate_inputs_and_outputs(config, grid)
config2 = config + [
'FusedLearnedInterpolation.extract_patch_method = "conv"',
]
_, outputs2 = self._generate_inputs_and_outputs(config2, grid)
for out1, out2 in zip(outputs1, outputs2):
self.assertAllClose(out1, out2, rtol=1e-6)
config2 = config + [
'FusedLearnedInterpolation.fuse_constraints = True',
]
_, outputs2 = self._generate_inputs_and_outputs(config2, grid)
for out1, out2 in zip(outputs1, outputs2):
self.assertAllClose(out1, out2, rtol=1e-6)
config2 = config + [
'FusedLearnedInterpolation.fuse_constraints = True',
'FusedLearnedInterpolation.fuse_patches = True',
]
_, outputs2 = self._generate_inputs_and_outputs(config2, grid)
for out1, out2 in zip(outputs1, outputs2):
self.assertAllClose(out1, out2, rtol=1e-6)
config2 = config + [
'FusedLearnedInterpolation.extract_patch_method = "conv"',
'FusedLearnedInterpolation.fuse_constraints = True',
'FusedLearnedInterpolation.tile_layout = (8, 1)',
]
_, outputs2 = self._generate_inputs_and_outputs(config2, grid)
for out1, out2 in zip(outputs1, outputs2):
self.assertAllClose(out1, out2, rtol=1e-6)
class MLModulesTest(test_util.TestCase):
def _generate_inputs_and_outputs(self, config, grid):
gin.enter_interactive_mode()
gin.parse_config(config)
dt = 0.1
physics_specs = physics_specifications.get_physics_specs()
def step_fwd(x):
# deepcopy triggers evaluation of references
derivative_modules = copy.deepcopy(gin.query_parameter(
'time_derivative_network_model.derivative_modules'))
model = equations.time_derivative_network_model(
grid, dt, physics_specs, derivative_modules)
return model(x)
step_model = hk.without_apply_rng(hk.transform(step_fwd))
inputs = []
for seed, _ in enumerate(grid.cell_faces):
rng_key = jax.random.PRNGKey(seed)
data = jax.random.uniform(rng_key, grid.shape, jnp.float32)
inputs.append(data)
inputs = jnp.stack(inputs, axis=-1)
rng = jax.random.PRNGKey(42)
with funcutils.init_context():
params = step_model.init(rng, inputs)
self.assertIsNotNone(params)
outputs = step_model.apply(params, inputs)
return inputs, outputs
@ml_test_parameters()
def test_epd_modules(
self,
forcing_module,
grid
):
"""Intgeration tests checking that `step_fn` produces expected shape."""
ndim = grid.ndim
latent_dims = 20
ml_module_name = 'time_derivative_network_model'
epd_towers = '(@enc/tower_module, @proc/tower_module, @dec/tower_module,)'
config = [
f'enc/tower_module.num_output_channels = {latent_dims}',
'enc/tower_module.tower_factory = @forward_tower_factory',
'proc/tower_module.tower_factory = @residual_block_tower_factory',
f'dec/tower_module.num_output_channels = {ndim}',
'dec/tower_module.tower_factory = @forward_tower_factory',
f'{ml_module_name}.derivative_modules = {epd_towers}',
f'{forcing_module}.scale = {FORCING_SCALE}',
f'NavierStokesPhysicsSpecs.forcing_module = @{forcing_module}',
'NavierStokesPhysicsSpecs.density = 1.',
'NavierStokesPhysicsSpecs.viscosity = 0.1',
'get_physics_specs.physics_specs_cls = @NavierStokesPhysicsSpecs',
]
inputs, outputs = self._generate_inputs_and_outputs(config, grid)
for u_input, u_output in zip(inputs, outputs):
self.assertEqual(u_input.shape, u_output.shape)
if __name__ == '__main__':
# Temporarily disable async dispatch on JAX CPU due to tsan error.
jax.config.update('jax_cpu_enable_async_dispatch', False)
absltest.main()
"""Components that apply forcing. See jax_cfd.base.forcings for forcing API."""
from typing import Callable
from typing import Optional, Tuple
import gin
from jax import numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import equations
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.spectral import utils as spectral_utils
ForcingFn = forcings.ForcingFn
ForcingModule = Callable[..., ForcingFn]
def sum_forcings(*forces: ForcingFn) -> ForcingFn:
"""Sum multiple forcing functions."""
  def forcing(v):
    return equations.sum_fields(*[f(v) for f in forces])
return forcing
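# Example (illustrative sketch, not part of the original module): combining a
# Kolmogorov forcing with a linear drag term. The scale, wavenumber and drag
# coefficient below are illustrative values, not defaults of the library.
def _example_combined_forcing(grid: grids.Grid) -> ForcingFn:
  return sum_forcings(
      kolmogorov_forcing(grid, scale=1.0, wavenumber=4),
      linear_forcing(grid, scale=-0.1))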
@gin.register
def filtered_linear_forcing(grid: grids.Grid,
scale: float,
lower_wavenumber: float = 0,
upper_wavenumber: float = 4) -> ForcingFn:
return forcings.filtered_linear_forcing(lower_wavenumber,
upper_wavenumber,
coefficient=scale,
grid=grid)
@gin.register
def linear_forcing(grid: grids.Grid,
scale: float) -> ForcingFn:
return forcings.linear_forcing(grid, scale)
@gin.register
def spectral_kolmogorov_forcing(grid):
return forcings.kolmogorov_forcing(
grid, 1.0, k=4, swap_xy=False, offsets=((0.0, 0.0), (0.0, 0.0)))
@gin.register
def vorticity_space_forcing(grid: grids.Grid, forcing_module: ForcingModule):
forcing_fn = forcing_module(grid, offsets=((0.0, 0.0), (0.0, 0.0)))
velocity_solve = spectral_utils.vorticity_to_velocity(grid)
kx, ky = grid.rfft_mesh()
fft, ifft = jnp.fft.rfft2, jnp.fft.irfft2
bc = boundaries.periodic_boundary_conditions(grid.ndim)
offset = (0.0, 0.0) # TODO(dresdner) do not hard code
def forcing_fn_ret(vorticity):
vorticity, = array_utils.split_axis(vorticity, axis=-1) # channel dim = 1
v = tuple(
grids.GridVariable(grids.GridArray(ifft(u), offset, grid), bc)
for u in velocity_solve(fft(vorticity)))
fhatu, fhatv = tuple(fft(u) for u in forcing_fn(v))
fhat_vorticity = 2j * jnp.pi * (fhatv * kx - fhatu * ky)
return ifft(fhat_vorticity)
return forcing_fn_ret
@gin.register
def kolmogorov_forcing(grid: grids.Grid, # pylint: disable=missing-function-docstring
scale: float = 0,
wavenumber: int = 2,
linear_coefficient: float = 0,
swap_xy: bool = False,
offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> ForcingFn:
force_fn = forcings.kolmogorov_forcing(
grid, scale, wavenumber, swap_xy, offsets=offsets)
if linear_coefficient != 0:
linear_force_fn = forcings.linear_forcing(grid, linear_coefficient)
force_fn = forcings.sum_forcings(force_fn, linear_force_fn)
return force_fn
@gin.register
def taylor_green_forcing(grid: grids.Grid,
scale: float = 0,
wavenumber: int = 2,
linear_coefficient: float = 0) -> ForcingFn:
force_fn = forcings.taylor_green_forcing(grid, scale, wavenumber)
if linear_coefficient != 0:
linear_force_fn = forcings.linear_forcing(grid, linear_coefficient)
force_fn = forcings.sum_forcings(force_fn, linear_force_fn)
return force_fn
@gin.register
def no_forcing(grid: grids.Grid) -> ForcingFn:
return forcings.no_forcing(grid)
"""Interpolation modules."""
import collections
import functools
from typing import Any, Callable, Tuple, Union
import gin
import jax.numpy as jnp
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import layers
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers
import numpy as np
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = interpolation.InterpolationFn
InterpolationModule = Callable[..., InterpolationFn]
InterpolationTransform = Callable[..., InterpolationFn]
FluxLimiter = interpolation.FluxLimiter
StencilSizeFn = Callable[
[Tuple[int, ...], Tuple[int, ...], Any], Tuple[int, ...]]
@gin.register
class FusedLearnedInterpolation:
"""Learned interpolator that computes interpolation coefficients in 1 pass.
Interpolation function that has pre-computed interpolation
coefficients for a given velocity field `v`. It uses a collection of
`SpatialDerivativeFromLogits` modules and a single neural network that
produces logits for all expected interpolations. Interpolations are keyed by
  `input_offset`, `target_offset` and an optional `tag`. The `tag` allows us to
  perform multiple interpolations between the same `input_offset` and
  `target_offset` with different weights.
"""
def __init__(
self,
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
v,
tags=(None,),
stencil_size: Union[int, StencilSizeFn] = 4,
tower_factory=towers.forward_tower_factory,
name='fused_learned_interpolation',
extract_patch_method='roll',
fuse_constraints=False,
fuse_patches=False,
constrain_with_conv=False,
tile_layout=None,
):
"""Constructs object and performs necessary pre-computate."""
del dt, physics_specs # unused.
derivative_orders = (0,) * grid.ndim
derivatives = collections.OrderedDict()
if isinstance(stencil_size, int):
stencil_size_fn = lambda *_: (stencil_size,) * grid.ndim
else:
stencil_size_fn = stencil_size
for u in v:
for target_offset in grids.control_volume_offsets(u):
for tag in tags:
key = (u.offset, target_offset, tag)
derivatives[key] = layers.SpatialDerivativeFromLogits(
stencil_size_fn(*key),
u.offset,
target_offset,
derivative_orders=derivative_orders,
steps=grid.step,
extract_patch_method=extract_patch_method,
tile_layout=tile_layout)
output_sizes = [deriv.subspace_size for deriv in derivatives.values()]
cnn_network = tower_factory(sum(output_sizes), grid.ndim, name=name)
inputs = jnp.stack([u.data for u in v], axis=-1)
all_logits = cnn_network(inputs)
if fuse_constraints:
self._interpolators = layers.fuse_spatial_derivative_layers(
derivatives, all_logits, fuse_patches=fuse_patches,
constrain_with_conv=constrain_with_conv)
else:
split_logits = jnp.split(all_logits, np.cumsum(output_sizes), axis=-1)
self._interpolators = {
k: functools.partial(derivative, logits=logits)
for (k, derivative), logits in zip(derivatives.items(), split_logits)
}
def __call__(self,
c: GridVariable,
offset: Tuple[int, ...],
v: GridVariableVector,
dt: float,
tag=None) -> GridVariable:
del dt # not used.
# TODO(dkochkov) Add decorator to expand/squeeze channel dim.
c = grids.GridVariable(
grids.GridArray(jnp.expand_dims(c.data, -1), c.offset, c.grid), c.bc)
# TODO(jamieas): Try removing the following line.
if c.offset == offset: return c
key = (c.offset, offset, tag)
interpolator = self._interpolators.get(key)
if interpolator is None:
raise KeyError(f'No interpolator for key {key}. '
f'Available keys: {list(self._interpolators.keys())}')
result = jnp.squeeze(interpolator(c.data), axis=-1)
return grids.GridVariable(
grids.GridArray(result, offset, c.grid), c.bc)
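# Example (illustrative sketch, not part of the original module): the
# interpolator precomputes logits for every (input offset, target offset, tag)
# key from the velocity field `v`, so it is constructed per step inside a Haiku
# transform. The tags and target offset below are assumptions for illustration;
# `c` is assumed to be one of the components of `v`.
def _example_fused_interpolation_transform(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
):
  import haiku as hk  # local import: Haiku is not otherwise used here.
  def forward(c: GridVariable, v: GridVariableVector):
    interp_fn = FusedLearnedInterpolation(
        grid, dt, physics_specs, v, tags=('u', 'c'))
    target_offset = grids.control_volume_offsets(c)[0]
    return interp_fn(c, target_offset, v, dt, tag='c')
  return hk.without_apply_rng(hk.transform(forward))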
def _nearest_neighbor_stencil_size_fn(
source_offset, target_offset, tag, stencil_size,
):
del tag # unused
return tuple(
1 if s == t else stencil_size
for s, t in zip(source_offset, target_offset)
)
@gin.register
def anisotropic_learned_interpolation(*args, stencil_size=2, **kwargs):
"""Like FusedLearnedInterpolation, but with anisotropic stencil."""
stencil_size_fn = functools.partial(
      _nearest_neighbor_stencil_size_fn, stencil_size=stencil_size,
)
return FusedLearnedInterpolation(
*args, stencil_size=stencil_size_fn, **kwargs
)
@gin.register
class IndividualLearnedInterpolation:
"""Trainable interpolation module.
This module uses a collection of SpatialDerivative modules that are applied
to inputs based on the combination of initial and target offsets. Currently
no symmetries are implemented and every new pair of offsets gets a separate
network.
"""
def __init__(
self,
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
v: GridArrayVector,
stencil_size=4,
tower_factory=towers.forward_tower_factory,
):
del v, dt, physics_specs # unused.
self._ndim = grid.ndim
self._tower_factory = functools.partial(tower_factory, ndim=grid.ndim)
self._stencil_sizes = (stencil_size,) * self._ndim
self._steps = grid.step
self._modules = {}
def _get_interpolation_module(self, offsets):
"""Constructs or retrieves a learned interpolation module."""
if offsets in self._modules:
return self._modules[offsets]
inputs_offset, target_offset = offsets
self._modules[offsets] = layers.SpatialDerivative(
self._stencil_sizes, inputs_offset, target_offset,
(0,) * self._ndim, self._tower_factory, self._steps)
return self._modules[offsets]
def __call__(
self,
c: GridVariable,
offset: Tuple[int, ...],
v: GridVariableVector,
dt: float,
) -> GridVariable:
"""Interpolates `c` to `offset`."""
del dt # not used.
if c.offset == offset: return c
offsets = (c.offset, offset)
c_input = jnp.expand_dims(c.data, axis=-1)
aux_inputs = [jnp.expand_dims(u.data, axis=-1) for u in v]
res = self._get_interpolation_module(offsets)(c_input, *aux_inputs)
return grids.GridVariable(
grids.GridArray(jnp.squeeze(res, axis=-1), offset, c.grid), c.bc)
@gin.register
def linear(*args, **kwargs):
del args, kwargs
return interpolation.linear
@gin.register
def upwind(*args, **kwargs):
del args, kwargs
return interpolation.upwind
@gin.register
def lax_wendroff(*args, **kwargs):
del args, kwargs
return interpolation.lax_wendroff
# TODO(dkochkov) make flux limiters configurable.
@gin.register
def tvd_limiter_transformation(
interpolation_fn: InterpolationFn,
limiter_fn: FluxLimiter = interpolation.van_leer_limiter,
) -> InterpolationFn:
"""Transformation function that applies flux limiter to `interpolation_fn`."""
return interpolation.apply_tvd_limiter(interpolation_fn, limiter_fn)
@gin.register
def transformed(
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
v: GridArrayVector,
base_interpolation_module: InterpolationModule = lax_wendroff,
transformation: InterpolationTransform = tvd_limiter_transformation,
) -> InterpolationFn:
"""Interpolation module that augments interpolation of the base module.
This module generates interpolation method that consists of that generated
by `base_interpolation_module` transformed by `transformation`. This allows
implementation of additional constraints such as TVD, in which case
`transformation` should apply a TVD limiter.
Args:
grid: grid on which the Navier-Stokes equation is discretized.
dt: time step to use for time evolution.
physics_specs: physical parameters of the simulation module.
v: input velocity field potentially used to pre-compute interpolations.
base_interpolation_module: base interpolation module to use.
transformation: transformation to apply to base interpolation function.
Returns:
Interpolation function.
"""
interpolation_fn = base_interpolation_module(grid, dt, physics_specs, v=v)
interpolation_fn = transformation(interpolation_fn)
return interpolation_fn
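# Example (illustrative sketch, not part of the original module): the manual
# equivalent of the gin wiring used in the tests,
#   modular_advection.c_interpolation_module = @transformed
#   transformed.base_interpolation_module = @lax_wendroff
#   transformed.transformation = @tvd_limiter_transformation
def _example_tvd_lax_wendroff(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    v: GridVariableVector,
) -> InterpolationFn:
  base_fn = lax_wendroff(grid, dt, physics_specs, v=v)  # interpolation.lax_wendroff
  return tvd_limiter_transformation(base_fn)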
"""Custom neural-net layers for physics simulations."""
import functools
from typing import (
Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union,
)
import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.ml import layers_util
from jax_cfd.ml import tiling
import numpy as np
import scipy.linalg
Array = Union[np.ndarray, jax.Array]
IntOrSequence = Union[int, Sequence[int]]
class PeriodicConvGeneral(hk.Module):
"""General periodic convolution module."""
def __init__(
self,
base_convolution: Callable[..., Any],
output_channels: int,
kernel_shape: Tuple[int, ...],
rate: int = 1,
tile_layout: Optional[Tuple[int, ...]] = None,
name: str = 'periodic_conv_general',
**conv_kwargs: Any
):
"""Constructs PeriodicConvGeneral module.
    We use `VALID` padding on `base_convolution` and explicit padding to achieve
    the effect of periodic boundary conditions. This module computes `paddings`
    and combines `jnp.pad` calls with the `base_convolution` module to produce
    the desired effect.
Args:
base_convolution: standard convolution module e.g. hk.Conv1D.
output_channels: number of output channels.
kernel_shape: shape of the kernel, compatible with `base_convolution`.
rate: dilation rate of the convolution.
tile_layout: optional layout for tiling spatial dimensions in a batch.
name: name of the module.
**conv_kwargs: additional arguments passed to `base_convolution`.
"""
super().__init__(name=name)
self._padding = []
for kernel_size in kernel_shape:
effective_kernel = kernel_size + (rate - 1) * (kernel_size - 1)
pad_left = effective_kernel // 2
self._padding.append((pad_left, effective_kernel - pad_left - 1))
self._tile_layout = tile_layout
self._conv_module = base_convolution(
output_channels=output_channels, kernel_shape=kernel_shape,
padding='VALID', rate=rate, **conv_kwargs)
def __call__(self, inputs):
return tiling.apply_convolution(
self._conv_module, inputs, self._tile_layout, self._padding)
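# Example (illustrative, not part of the original module): the explicit padding
# computed in PeriodicConvGeneral.__init__ for a kernel of size 5 with rate 1.
def _example_periodic_padding(kernel_size: int = 5, rate: int = 1):
  effective_kernel = kernel_size + (rate - 1) * (kernel_size - 1)
  pad_left = effective_kernel // 2
  # For kernel_size=5, rate=1 this is (2, 2): with two periodically padded
  # cells on each side, the VALID convolution keeps the spatial size unchanged.
  return (pad_left, effective_kernel - pad_left - 1)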
class PeriodicConv1D(PeriodicConvGeneral):
"""Periodic convolution module in 1D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int],
rate: int = 1,
tile_layout: Optional[Tuple[int]] = None,
name='periodic_conv_1d',
**conv_kwargs
):
"""Constructs PeriodicConv1D module."""
super().__init__(
base_convolution=hk.Conv1D,
output_channels=output_channels,
kernel_shape=kernel_shape,
rate=rate,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
class PeriodicConv2D(PeriodicConvGeneral):
"""Periodic convolution module in 2D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int, int],
rate: int = 1,
tile_layout: Optional[Tuple[int, int]] = None,
name='periodic_conv_2d',
**conv_kwargs
):
"""Constructs PeriodicConv2D module."""
super().__init__(
base_convolution=hk.Conv2D,
output_channels=output_channels,
kernel_shape=kernel_shape,
rate=rate,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
class PeriodicConv3D(PeriodicConvGeneral):
"""Periodic convolution module in 3D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int, int, int],
rate: int = 1,
tile_layout: Optional[Tuple[int, int, int]] = None,
name='periodic_conv_3d',
**conv_kwargs
):
"""Constructs PeriodicConv3D module."""
super().__init__(
base_convolution=hk.Conv3D,
output_channels=output_channels,
kernel_shape=kernel_shape,
rate=rate,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
class MirrorConvGeneral(hk.Module):
"""General periodic convolution module."""
def __init__(self,
base_convolution: Callable[..., Any],
output_channels: int,
kernel_shape: Tuple[int, ...],
rate: int = 1,
tile_layout: Optional[Tuple[int, ...]] = None,
name: str = 'mirror_conv_general',
**conv_kwargs: Any):
"""Constructs MirrorConvGeneral module.
    We use `VALID` padding on `base_convolution` and explicit mirror padding
    beyond the boundaries. This module computes `paddings` and combines
    `jnp.pad` calls with the `base_convolution` module.
Args:
base_convolution: standard convolution module e.g. hk.Conv1D.
output_channels: number of output channels.
kernel_shape: shape of the kernel, compatible with `base_convolution`.
rate: dilation rate of the convolution.
tile_layout: optional layout for tiling spatial dimensions in a batch.
name: name of the module.
**conv_kwargs: additional arguments passed to `base_convolution`.
"""
super().__init__(name=name)
self._padding = np.zeros(2).astype('int')
for kernel_size in kernel_shape:
effective_kernel = kernel_size + (rate - 1) * (kernel_size - 1)
pad_left = effective_kernel // 2
self._padding += np.array([pad_left,
effective_kernel - pad_left - 1]).astype('int')
self._tile_layout = tile_layout
self._conv_module = base_convolution(
output_channels=output_channels, kernel_shape=kernel_shape,
padding='VALID', rate=rate, **conv_kwargs)
def _expand_var(self, var):
var = var.bc.pad_all(
var, (self._padding,) * var.grid.ndim, mode=boundaries.Padding.MIRROR)
return jnp.expand_dims(var.data, axis=-1)
def __call__(self, inputs):
input_data = tuple(self._expand_var(var) for var in inputs)
input_data = array_utils.concat_along_axis(
        jax.tree_util.tree_leaves(input_data), axis=-1)
outputs = self._conv_module(input_data)
outputs = array_utils.split_axis(outputs, -1)
outputs = tuple(
var_input.bc.impose_bc(
grids.GridArray(var, var_input.offset, var_input.grid))
for var, var_input in zip(outputs, inputs))
return outputs
class MirrorConv2D(MirrorConvGeneral):
"""Mirror convolution module in 2D."""
def __init__(self,
output_channels: int,
kernel_shape: Tuple[int, int],
rate: int = 1,
tile_layout: Optional[Tuple[int, int]] = None,
name='mirror_conv_2d',
**conv_kwargs):
"""Constructs MirrorConv2D module."""
super().__init__(
base_convolution=hk.Conv2D,
output_channels=output_channels,
kernel_shape=kernel_shape,
rate=rate,
tile_layout=tile_layout,
name=name,
**conv_kwargs)
class PeriodicConvTransposeGeneral(hk.Module):
"""General periodic transpose convolution module."""
def __init__(
self,
base_convolution: Callable[..., Any],
output_channels: int,
kernel_shape: Tuple[int, ...],
stride: int = 1,
tile_layout: Optional[Tuple[int, ...]] = None,
name: str = 'periodic_conv_transpose_general',
**conv_kwargs: Any
):
"""Constructs PeriodicConvTransposeGeneral module.
    To achieve the effect of periodic convolutions we first pad the inputs at
    the start of each spatial axis in wrap mode, ensuring that the output
    generated by the original slice of the inputs receives contributions from
    periodic images when `base_convolution` is applied. The `base_convolution`
    is applied with `VALID` padding, followed by slicing to discard the boundary
    values. Additionally, we roll the output to avoid drift of the spatial axes:
    in the standard implementation of transposed convolutions the kernel applied
    to index [i] affects outputs at indices [i: i + kernel_size]; we treat the
    center of the affected field as the spatial location of the output and hence
    shift it back by half of the kernel size.
Args:
base_convolution: standard transpose convolution module.
output_channels: number of output channels.
kernel_shape: shape of the kernel, compatible with `base_convolution`.
stride: stride to use in `base_convolution`.
tile_layout: optional layout for tiling spatial dimensions in a batch.
name: name of the module.
**conv_kwargs: additional arguments passed to `base_convolution`.
"""
if tile_layout is not None:
raise NotImplementedError(
"tile_layout doesn't work yet for transpose convolutions")
super().__init__(name=name)
self._stride = stride
self._kernel_shape = kernel_shape
self._padding = []
self._roll_shifts = []
for kernel_size in kernel_shape:
      # The left pad should be large enough that the contribution from the
      # leftmost padded element just reaches the output of the input's original
      # first index, i.e. stride * left_pad (output of the first index) should
      # be less than or equal to kernel_size (last affected value of the input).
pad_left = kernel_size // stride + 1
self._padding.append((pad_left, 0))
# we shift by half a kernel size at the end to recover spatial alignment.
self._roll_shifts.append(-((kernel_size - 1) // 2))
self._tile_layout = tile_layout
self._conv_module = base_convolution(
output_channels=output_channels, kernel_shape=kernel_shape,
stride=stride, padding='VALID', **conv_kwargs)
def __call__(self, inputs):
"""Applies PeriodicTransposeConvolution to `inputs`.
Args:
inputs: array with spatial and channel axes to which
PeriodicTransposeConvolution is applied.
Returns:
`inputs` convolved with the kernel of the module with periodic padding.
"""
ndim = len(self._kernel_shape)
output_slice = []
for axis, (left_pad, _) in enumerate(self._padding[:ndim]):
axis_size = inputs.shape[axis]
output_start = self._stride * left_pad
output_end = self._stride * (axis_size + left_pad)
output_slice.append(slice(output_start, output_end))
output_slice.append(slice(None, None))
output = tiling.apply_convolution(
self._conv_module, inputs, self._tile_layout, self._padding)
sliced_output = output[tuple(output_slice)]
return jnp.roll(sliced_output, self._roll_shifts, list(range(ndim)))
class PeriodicConvTranspose1D(PeriodicConvTransposeGeneral):
"""Periodic transpose convolution module in 1D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int],
stride: int = 1,
tile_layout: Optional[Tuple[int]] = None,
name='periodic_conv_transpose_1d',
**conv_kwargs
):
"""Constructs PeriodicConv1D module."""
super().__init__(
base_convolution=hk.Conv1DTranspose,
output_channels=output_channels,
kernel_shape=kernel_shape,
stride=stride,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
class PeriodicConvTranspose2D(PeriodicConvTransposeGeneral):
"""Periodic transpose convolution module in 2D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int, int],
stride: int = 1,
tile_layout: Optional[Tuple[int, int]] = None,
name='periodic_conv_transpose_2d',
**conv_kwargs
):
"""Constructs PeriodicConv2D module."""
super().__init__(
base_convolution=hk.Conv2DTranspose,
output_channels=output_channels,
kernel_shape=kernel_shape,
stride=stride,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
class PeriodicConvTranspose3D(PeriodicConvTransposeGeneral):
"""Periodic transpose convolution module in 3D."""
def __init__(
self,
output_channels: int,
kernel_shape: Tuple[int, int, int],
stride: int = 1,
tile_layout: Optional[Tuple[int, int, int]] = None,
name='periodic_conv_transpose_3d',
**conv_kwargs
):
"""Constructs PeriodicConv3D module."""
super().__init__(
base_convolution=hk.Conv3DTranspose,
output_channels=output_channels,
kernel_shape=kernel_shape,
stride=stride,
tile_layout=tile_layout,
name=name,
**conv_kwargs
)
def rescale_to_range(
inputs: Array,
min_value: float,
max_value: float,
axes: Tuple[int, ...]
) -> Array:
"""Rescales inputs to [min_value, max_value] range.
Note that this function performs an input-dependent transformation, which
might not be suitable for models that aim to learn different dynamics at
different scales.
Args:
inputs: array to be rescaled to [min_value, max_value] range.
min_value: value to which the smallest entry of `inputs` is mapped.
max_value: value to which the largest entry of `inputs` is mapped.
axes: `inputs` axes across which we search for smallest and largest values.
Returns:
`inputs` rescaled to [min_value, max_value] range.
"""
inputs_max = jnp.max(inputs, axis=axes, keepdims=True)
inputs_min = jnp.min(inputs, axis=axes, keepdims=True)
scale = (inputs_max - inputs_min) / (max_value - min_value)
return (inputs - inputs_min) / scale + min_value
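# Editor's note: a minimal usage sketch (not part of the original library).
# `rescale_to_range` maps the global minimum/maximum over `axes` onto the
# requested range, e.g.:
#   x = jnp.asarray([[0., 2.], [4., 8.]])
#   y = rescale_to_range(x, min_value=0., max_value=1., axes=(0, 1))
#   # y == [[0., 0.25], [0.5, 1.]]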
class NonPeriodicConvGeneral(hk.Module):
"""General periodic convolution module."""
def __init__(self,
base_convolution: Callable[..., Any],
output_channels: int,
kernel_shape: Tuple[int, ...],
rate: int = 1,
name: str = 'periodic_conv_general',
**conv_kwargs: Any):
"""Constructs NonPeriodicConvGeneral module similar to a Periodic one.
Args:
base_convolution: standard convolution module e.g. hk.Conv1D.
output_channels: number of output channels.
kernel_shape: shape of the kernel, compatible with `base_convolution`.
rate: dilation rate of the convolution.
name: name of the module.
**conv_kwargs: additional arguments passed to `base_convolution`.
"""
super().__init__(name=name)
self._conv_module = base_convolution(
output_channels=output_channels,
kernel_shape=kernel_shape,
padding='VALID',
rate=rate,
**conv_kwargs)
def __call__(self, inputs):
output = self._conv_module(jnp.expand_dims(inputs, axis=0))
return jnp.squeeze(output, axis=0)
class NonPeriodicConv1D(NonPeriodicConvGeneral):
"""Periodic convolution module in 1D."""
def __init__(self,
output_channels: int,
kernel_shape: Tuple[int, ...],
rate: int = 1,
name='periodic_conv_1d',
**conv_kwargs):
"""Constructs PeriodicConv1D module."""
super().__init__(
base_convolution=hk.Conv1D,
output_channels=output_channels,
kernel_shape=kernel_shape,
rate=rate,
name=name,
**conv_kwargs)
class PolynomialConstraint:
"""Module that parametrizes coefficients of polynomially accurate derivatives.
Generates stencil coefficients that are guaranteed to approximate derivative
of `derivative_orders[i]` along ith dimension with polynomial accuracy of
`accuracy_order` order. The approximation is enforced by taking a linear
superposition of the nullspace of linear constraints combined with a bias
solution, which can be either specified directly using `bias` argument or
generated automatically using `layers_util.polynomial_accuracy_coefficients`.
"""
def __init__(
self,
stencils: Sequence[np.ndarray],
derivative_orders: Tuple[int, ...],
method: layers_util.Method,
steps: Tuple[float, ...],
accuracy_order: int = 1,
bias_accuracy_order: int = 1,
bias: Optional[Array] = None,
precision: lax.Precision = lax.Precision.HIGHEST
):
"""Constructs the object.
Args:
stencils: sequence of 1d stencils, one per grid dimension
derivative_orders: derivative orders along corresponding directions.
method: discretization method (finite volumes or finite differences).
steps: spatial separations between the adjacent cells.
accuracy_order: order to which polynomial accuracy is enforced.
bias_accuracy_order: integer order of polynomial accuracy to use for the
bias term. Only used if bias is not provided.
bias: np.ndarray of shape (grid_size,) to which zero-vectors will be
mapped. Must satisfy polynomial accuracy to the requested order. By
default, we use standard low-order coefficients for the given grid.
precision: numerical precision for matrix multiplication. Only relevant on
TPUs.
"""
self.precision = precision
grid_steps = {*steps}
if len(grid_steps) != 1:
raise ValueError('nonuniform steps not supported by PolynomialConstraint')
grid_step, = grid_steps
# stencil coefficients `c` satisfying `constraint_matrix @ c = rhs`
# satisfies polynomial accuracy constraint of the given order
constraint_matrix, rhs = layers_util.polynomial_accuracy_constraints(
stencils, method, derivative_orders, accuracy_order, grid_step)
if bias is None:
bias_grid = layers_util.polynomial_accuracy_coefficients(
stencils, method, derivative_orders, bias_accuracy_order, grid_step)
bias = bias_grid.ravel()
self.bias = bias
norm = np.linalg.norm(np.dot(constraint_matrix, bias) - rhs)
if norm > 1e-8:
raise ValueError('invalid bias, not in nullspace')
# https://en.wikipedia.org/wiki/Kernel_(linear_algebra)#Nonhomogeneous_systems_of_linear_equations
_, _, v = np.linalg.svd(constraint_matrix)
nullspace_size = constraint_matrix.shape[1] - constraint_matrix.shape[0]
if not nullspace_size:
raise ValueError(
'there is only one valid solution accurate to this order')
# nullspace from the SVD is always normalized such that its singular values
# are 1 or 0, which means it's actually independent of the grid spacing.
self._nullspace_size = nullspace_size
self.nullspace = v[-nullspace_size:]
self.nullspace /= (grid_step**np.array(derivative_orders)).prod()
@property
def subspace_size(self) -> int:
"""Returns the size of the coefficients subspace with desired accuracy."""
return self._nullspace_size
def __call__(
self,
inputs: Array
) -> Array:
"""Returns polynomially accurate coefficients parametrized by `inputs`.
Args:
inputs: array whose last dimension represents linear superposition of
valid polynomially accurate coefficients. Last dimension must be equal
to `subspace_size`.
Returns:
array whose last dimension represents valid coefficients that approximate
`derivative_orders` with polynomial accuracy on a stencil specified in
`stencils` at position (0.,) * ndims.
"""
return self.bias + jnp.tensordot(inputs, self.nullspace, axes=[-1, 0],
precision=self.precision)
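# Editor's note: a usage sketch (not part of the original library), assuming
# a 1-D finite-difference derivative on a uniform grid:
#   grid_shape = (64,)
#   stencils = [np.array([-0.1, 0., 0.1])]
#   constraint = PolynomialConstraint(
#       stencils, derivative_orders=(1,),
#       method=layers_util.Method.FINITE_DIFFERENCE, steps=(0.1,))
#   logits = jnp.zeros(grid_shape + (constraint.subspace_size,))
#   coefficients = constraint(logits)  # shape: grid_shape + (3,)
# With zero logits the output equals the bias, i.e. the standard low-order
# finite-difference coefficients for this stencil.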
class StencilCoefficients:
"""Module that approximates stencil coefficients with polynomial accuracy.
Generates stencil coefficients that approximate a spatial derivative of
order `derivative_orders[i]` along the i-th dimension by combining a trainable
model generated by `tower_factory` with a `PolynomialConstraint` layer.
"""
def __init__(
self,
stencils: Sequence[np.ndarray],
derivative_orders: Tuple[int, ...],
tower_factory: Callable[[int], Callable[..., Any]],
steps: Tuple[float, ...],
method: layers_util.Method = layers_util.Method.FINITE_VOLUME,
**kwargs: Any,
):
"""Constructs the object.
Args:
stencils: sequence of 1d stencils, one per grid dimension
derivative_orders: derivative orders along corresponding directions.
tower_factory: callable that constructs a neural network with specified
number of output channels and the same spatial resolution as its inputs.
steps: spatial separations between the adjacent cells.
method: discretization method passed to PolynomialConstraint.
**kwargs: additional arguments to be passed to PolynomialConstraint
constructor.
"""
self._polynomial_constraint = PolynomialConstraint(
stencils, derivative_orders, method, steps, **kwargs)
self._tower = tower_factory(self._polynomial_constraint.subspace_size)
def __call__(self, inputs: Array, **kwargs) -> Array:
"""Returns coefficients approximating derivative conditioned on `inputs`."""
parametrization = self._tower(inputs, **kwargs)
return self._polynomial_constraint(parametrization)
class SpatialDerivativeFromLogits:
"""Module that transforms logits to polynomially accurate derivatives.
Applies `PolynomialConstraint` layer to input logits and combines the
resulting coefficients with basis. Compared to `SpatialDerivative`, this
module does not compute `logits`, but takes them as an argument.
"""
def __init__(
self,
stencil_shape: Tuple[int, ...],
input_offset: Tuple[float, ...],
target_offset: Tuple[float, ...],
derivative_orders: Tuple[int, ...],
steps: Tuple[float, ...],
extract_patch_method: str = 'roll',
tile_layout: Optional[Tuple[int, ...]] = None,
method: layers_util.Method = layers_util.Method.FINITE_VOLUME,
):
self.stencil_shape = stencil_shape
self.roll, shift = layers_util.get_roll_and_shift(
input_offset, target_offset)
stencils = layers_util.get_stencils(stencil_shape, shift, steps)
self.constraint = PolynomialConstraint(
stencils, derivative_orders, method, steps)
self._extract_patch_method = extract_patch_method
self.tile_layout = tile_layout
@property
def subspace_size(self) -> int:
return self.constraint.subspace_size
@property
def stencil_size(self) -> int:
return int(np.prod(self.stencil_shape))
def _validate_logits(self, logits):
if logits.shape[-1] != self.subspace_size:
raise ValueError('The last dimension of `logits` did not match subspace '
f'size; {logits.shape[-1]} vs. {self.subspace_size}')
def extract_patches(self, inputs):
rolled = jnp.roll(inputs, self.roll)
patches = layers_util.extract_patches(
rolled, self.stencil_shape,
self._extract_patch_method, self.tile_layout)
return patches
@functools.partial(jax.named_call, name='SpatialDerivativeFromLogits')
def __call__(self, inputs, logits):
self._validate_logits(logits)
coefficients = self.constraint(logits)
patches = self.extract_patches(inputs)
return layers_util.apply_coefficients(coefficients, patches)
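# Editor's note: an illustrative sketch (not part of the original library).
#   module = SpatialDerivativeFromLogits(
#       stencil_shape=(5,), input_offset=(0.5,), target_offset=(1.0,),
#       derivative_orders=(1,), steps=(0.1,))
#   logits = jnp.zeros((256,) + (module.subspace_size,))
#   inputs = jnp.ones((256, 1))
#   derivative = module(inputs, logits)  # shape (256, 1)
# i.e. the logits parametrize the stencil coefficients, while the module
# handles patch extraction and the polynomial accuracy constraint.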
T = TypeVar('T')
def fuse_spatial_derivative_layers(
derivatives: Dict[T, SpatialDerivativeFromLogits],
all_logits: jnp.ndarray,
*,
constrain_with_conv: bool = False,
fuse_patches: bool = False,
) -> Dict[T, Callable[[jnp.ndarray], jnp.ndarray]]:
"""Evaluate spatial derivatives by fusing together constraints.
Despite the additional calculation, this can be faster on TPUs because the
full block diagonal constraint matrix is small enough to fit within a 128x128
matrix.
Args:
derivatives: mapping from key to SpatialDerivativeFromLogits.
all_logits: stacked logits to use as input into spatial derivatives.
constrain_with_conv: whether to constrain with a 1x1 convolution instead of
direct matrix multiplication.
fuse_patches: whether to also fuse the extraction of patches.
Returns:
Mapping from each key in `derivatives` to a function that evaluates the
corresponding derivative when applied to an input array.
"""
joint_bias = jnp.concatenate(
[derivative.constraint.bias for derivative in derivatives.values()])
joint_nullspace = scipy.linalg.block_diag(
*[deriv.constraint.nullspace for deriv in derivatives.values()]
)
precision, = {deriv.constraint.precision for deriv in derivatives.values()}
tile_layout, = {deriv.tile_layout for deriv in derivatives.values()}
if constrain_with_conv:
ndim = len(tile_layout)
kernel = jnp.expand_dims(
joint_nullspace.astype(np.float32), axis=tuple(range(ndim)))
all_coefficients = joint_bias + layers_util.periodic_convolution(
all_logits, kernel, tile_layout=tile_layout, precision=precision)
else:
if tile_layout is not None:
all_logits = tiling.space_to_batch(all_logits, tile_layout)
all_coefficients = joint_bias + jnp.tensordot(
all_logits, joint_nullspace, axes=[-1, 0], precision=precision)
if tile_layout is not None:
all_coefficients = tiling.batch_to_space(all_coefficients, tile_layout)
stencil_sizes = [deriv.stencil_size for deriv in derivatives.values()]
coefficients_list = jnp.split(
all_coefficients, np.cumsum(stencil_sizes), axis=-1)
coefficients_map = dict(zip(derivatives, coefficients_list))
stencil_shapes = [deriv.stencil_shape for deriv in derivatives.values()]
for k, deriv in derivatives.items():
if any(r != 0 for r in deriv.roll):
raise ValueError(f'derivative {k} uses roll: {deriv.roll}')
@functools.partial(jax.named_call, name='evaluate_derivatives')
def evaluate(key, inputs):
if fuse_patches:
all_patches = layers_util.fused_extract_patches(
inputs, stencil_shapes, tile_layout)
all_terms = all_coefficients * all_patches
split_terms = jnp.split(all_terms, np.cumsum(stencil_sizes), axis=-1)
index = list(derivatives).index(key)
return jnp.sum(split_terms[index], axis=-1, keepdims=True)
else:
patches = derivatives[key].extract_patches(inputs)
return layers_util.apply_coefficients(coefficients_map[key], patches)
return {k: functools.partial(evaluate, k) for k in derivatives}
class SpatialDerivative:
"""Module that learns spatial derivative with polynomial accuracy.
Combines StencilCoefficients with extract_stencils to construct a trainable
model that approximates spatial derivative.
"""
def __init__(
self,
stencil_shape: Tuple[int, ...],
input_offset: Tuple[float, ...],
target_offset: Tuple[float, ...],
derivative_orders: Tuple[int, ...],
tower_factory: Callable[[int], Callable[..., Any]],
steps: Tuple[float, ...],
extract_patch_method: str = 'roll',
tile_layout: Optional[Tuple[int, ...]] = None,
):
self._stencil_shape = stencil_shape
self._roll, self._shift = layers_util.get_roll_and_shift(
input_offset, target_offset)
stencils = layers_util.get_stencils(stencil_shape, self._shift, steps)
self._coefficients_module = StencilCoefficients(
stencils, derivative_orders, tower_factory, steps)
self._extract_patch_method = extract_patch_method
self._tile_layout = tile_layout
@functools.partial(jax.named_call, name='SpatialDerivative')
def __call__(self, inputs, *auxiliary_inputs):
"""Computes spatial derivative of `inputs` evaluated at `offset`."""
# TODO(jamieas): consider moving this roll inside `extract_patches`. For the
# `roll` implementation of `extract_patches`, we can simply add it to the
# `shifts`. For the `conv` implementation, we may be able to effectively
# roll the input by adjusting how arrays are padded.
rolled = jnp.roll(inputs, self._roll)
patches = layers_util.extract_patches(
rolled, self._stencil_shape,
self._extract_patch_method, self._tile_layout)
if auxiliary_inputs:
auxiliary_inputs = [jnp.roll(aux, self._roll) for aux in auxiliary_inputs]
rolled = jnp.concatenate([rolled, *auxiliary_inputs], axis=-1)
coefficients = self._coefficients_module(rolled)
return layers_util.apply_coefficients(coefficients, patches)
"""Tests for google3.research.simulation.whirl.layers."""
import functools
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.ml import layers
from jax_cfd.ml import layers_util
import numpy as np
KERNEL_SIZES = [3, 4]
ROLL_BY = [2, 4, 7]
RATES = [1, 2]
def conv_test_parameters(conv_modules, ndim, test_rate=True):
if test_rate:
product = itertools.product(conv_modules, KERNEL_SIZES, ROLL_BY, RATES)
else:
product = itertools.product(conv_modules, KERNEL_SIZES, ROLL_BY, [None])
parameters = []
for conv_module, kernel_size, roll_by, rate in product:
name = '_'.join([
conv_module.__name__, f'kernel_size_{kernel_size}',
f'rollby_{roll_by}', f'rate_{rate}'])
parameters.append(dict(
testcase_name=name,
conv_module=conv_module,
kernel_shape=(kernel_size,) * ndim,
roll_by=roll_by,
rate=rate))
return parameters
class ConvPeriodicTest(test_util.TestCase):
"""Tests all convolutions with periodic boundary conditions."""
@parameterized.named_parameters(
*(conv_test_parameters([layers.PeriodicConvTranspose1D], 1, False) +
conv_test_parameters([layers.PeriodicConv1D], 1, True)))
def test_equivariance_1d(self, conv_module, kernel_shape, roll_by, rate):
input_shape = (32, 1)
inputs = np.random.uniform(size=input_shape)
def net_forward(x):
if rate is not None:
module = conv_module(1, kernel_shape, rate=rate)
else:
module = conv_module(1, kernel_shape)
return module(x)
net = hk.without_apply_rng(hk.transform(net_forward))
rng = hk.PRNGSequence(42)
net_params = net.init(next(rng), inputs)
roll_conv = jnp.roll(net.apply(net_params, inputs), roll_by, 0)
conv_roll = net.apply(net_params, jnp.roll(inputs, roll_by, 0))
np.testing.assert_allclose(roll_conv, conv_roll)
@parameterized.named_parameters(
*(conv_test_parameters([layers.PeriodicConvTranspose2D], 2, False) +
conv_test_parameters([layers.PeriodicConv2D], 2, True)))
def test_equivariance_2d(self, conv_module, kernel_shape, roll_by, rate):
input_shape = (32, 32, 1)
inputs = np.random.uniform(size=input_shape)
def net_forward(x):
if rate is not None:
module = conv_module(1, kernel_shape, rate=rate)
else:
module = conv_module(1, kernel_shape)
return module(x)
net = hk.without_apply_rng(hk.transform(net_forward))
rng = hk.PRNGSequence(42)
net_params = net.init(next(rng), inputs)
roll_conv_x = jnp.roll(net.apply(net_params, inputs), roll_by, 0)
conv_roll_x = net.apply(net_params, jnp.roll(inputs, roll_by, 0))
roll_conv_y = jnp.roll(net.apply(net_params, inputs), roll_by, 1)
conv_roll_y = net.apply(net_params, jnp.roll(inputs, roll_by, 1))
np.testing.assert_allclose(roll_conv_x, conv_roll_x)
np.testing.assert_allclose(roll_conv_y, conv_roll_y)
@parameterized.named_parameters(
*(conv_test_parameters([layers.PeriodicConvTranspose3D], 3, False) +
conv_test_parameters([layers.PeriodicConv3D], 3, True)))
def test_equivariance_3d(self, conv_module, kernel_shape, roll_by, rate):
input_shape = (16, 16, 16, 1)
inputs = np.random.uniform(size=input_shape)
def net_forward(x):
if rate is not None:
module = conv_module(1, kernel_shape, rate=rate)
else:
module = conv_module(1, kernel_shape)
return module(x)
net = hk.without_apply_rng(hk.transform(net_forward))
rng = hk.PRNGSequence(42)
net_params = net.init(next(rng), inputs)
roll_conv_x = jnp.roll(net.apply(net_params, inputs), roll_by, 0)
conv_roll_x = net.apply(net_params, jnp.roll(inputs, roll_by, 0))
roll_conv_y = jnp.roll(net.apply(net_params, inputs), roll_by, 1)
conv_roll_y = net.apply(net_params, jnp.roll(inputs, roll_by, 1))
roll_conv_z = jnp.roll(net.apply(net_params, inputs), roll_by, 2)
conv_roll_z = net.apply(net_params, jnp.roll(inputs, roll_by, 2))
np.testing.assert_allclose(roll_conv_x, conv_roll_x)
np.testing.assert_allclose(roll_conv_y, conv_roll_y)
np.testing.assert_allclose(roll_conv_z, conv_roll_z)
@parameterized.named_parameters(
{
'testcase_name': '1d',
'conv_module': layers.PeriodicConv1D,
'input_shape': (32, 3),
'kwargs': dict(output_channels=3, kernel_shape=(5,)),
'tile_layout': (4,),
},
{
'testcase_name': '2d_3x3',
'conv_module': layers.PeriodicConv2D,
'input_shape': (8, 8, 3),
'kwargs': dict(output_channels=3, kernel_shape=(3, 3)),
'tile_layout': (2, 4),
},
{
'testcase_name': '2d_4x4',
'conv_module': layers.PeriodicConv2D,
'input_shape': (8, 8, 3),
'kwargs': dict(output_channels=3, kernel_shape=(4, 4)),
'tile_layout': (2, 1),
},
{
'testcase_name': '3d',
'conv_module': layers.PeriodicConv3D,
'input_shape': (8, 8, 8, 3),
'kwargs': dict(output_channels=3, kernel_shape=(3, 3, 3)),
'tile_layout': (2, 4, 4),
},
)
def test_tile_layout(self, conv_module, input_shape, kwargs, tile_layout):
# pylint: disable=unnecessary-lambda
inputs = np.random.uniform(size=input_shape)
untiled_layout = (1,) * len(tile_layout)
base_module = lambda x: conv_module(**kwargs, tile_layout=untiled_layout)(x)
base_net = hk.without_apply_rng(hk.transform(base_module))
tiled_module = lambda x: conv_module(**kwargs, tile_layout=tile_layout)(x)
tiled_net = hk.without_apply_rng(hk.transform(tiled_module))
params = base_net.init(jax.random.PRNGKey(42), inputs)
base_out = base_net.apply(params, inputs)
tiled_out = tiled_net.apply(params, inputs)
np.testing.assert_allclose(base_out, tiled_out, atol=1e-6)
@parameterized.named_parameters([
('size_60_stride_1', 60, 1),
('size_60_stride_2', 60, 2),
('size_60_stride_3', 60, 3),
('size_60_stride_5', 60, 5),
('size_45_stride_3', 45, 3),
])
def test_roundtrip_spatial_alignment(self, input_size, stride):
"""Tests that ConvTansposed(Conv(x)) with identity params is identity op."""
input_shape = (input_size, 1)
inputs = np.random.uniform(size=input_shape)
w_init = hk.initializers.Constant(
np.reshape(np.asarray([0., 1., 0.]), (3, 1, 1)))
b_init = hk.initializers.Constant(np.zeros((1,)))
def net_forward(x):
conv_args = {
'output_channels': 1,
'kernel_shape': (3,),
'stride': stride,
'w_init': w_init,
'b_init': b_init,
}
conv = layers.PeriodicConv1D(**conv_args)
conv_transpose = layers.PeriodicConvTranspose1D(**conv_args)
return conv_transpose(conv(x))
net = hk.without_apply_rng(hk.transform(net_forward))
rng = hk.PRNGSequence(42)
net_params = net.init(next(rng), inputs)
output = net.apply(net_params, inputs)
stride_mask = np.expand_dims(np.asarray([1] + [0] * (stride -1)), -1)
mask = np.tile(stride_mask, (input_shape[0] // stride, 1))
expected_output = inputs * mask
np.testing.assert_allclose(output, expected_output)
class RescaleToRangeTest(test_util.TestCase):
"""Tests `rescale_to_range` layer."""
@parameterized.named_parameters([
('rescale_1d', (32,)),
('rescale_2d', (16, 24)),
('rescale_3d', (8, 16, 16)),
])
def test_min_max_shape(self, shape):
"""Tests that rescaled values have expected shapes and min/max values."""
min_value = -0.4
max_value = 0.73
axes = tuple(np.arange(len(shape)))
input_values = np.random.uniform(low=-5., high=5., size=shape)
rescale = functools.partial(
layers.rescale_to_range, min_value=min_value, max_value=max_value,
axes=axes)
output = rescale(input_values)
actual_max = jnp.max(output)
actual_min = jnp.min(output)
self.assertEqual(shape, output.shape) # shape shouldn't change
self.assertAllClose(min_value, actual_min, atol=1e-6)
self.assertAllClose(max_value, actual_max, atol=1e-6)
@parameterized.named_parameters([
('rescale_1d', (32,)),
('rescale_2d', (16, 24)),
('rescale_3d', (8, 16, 16)),
])
def test_correctness(self, shape):
"""Tests that rescaled values belong to expected range."""
min_value = 0.
max_value = 1.
axes = tuple(np.arange(len(shape)))
num_elements = np.prod(shape)
input_values = np.random.uniform(low=-5., high=5., size=num_elements)
input_values[0] = -10.
input_values[-1] = 10.
input_values = np.reshape(input_values, newshape=shape)
rescale = functools.partial(
layers.rescale_to_range, min_value=min_value, max_value=max_value,
axes=axes)
# no init call is needed here, since `rescale_to_range` has no parameters.
actual_output = rescale(input_values)
expected_output = (input_values + 10.) / 20.
self.assertAllClose(expected_output, actual_output)
def _name_test(ndim, stencil, derivs):
return '{}d_stencil_{}_derivatives_{}'.format(ndim, stencil, derivs)
TESTS_1D = [
(_name_test(1, stencil, derivs), (32,), stencil, derivs)
for stencil, derivs in itertools.product([[3], [4]], [[0], [1]])
]
TESTS_2D = [
(_name_test(2, stencil, derivs), (32, 32), stencil, derivs)
for stencil, derivs in itertools.product([[3, 3], [4, 4]], [[0, 0], [1, 1]])
]
TESTS_3D = [
(_name_test(3, stencil, derivs), (16, 16, 16), stencil, derivs)
for stencil, derivs in itertools.product([[3, 3, 3], [4, 4, 4]],
[[0, 0, 0], [1, 1, 0]])
]
TESTS_ALL = TESTS_1D + TESTS_2D + TESTS_3D
def _make_test_stencil(size, step):
return np.array([i * step for i in range(-size // 2 + 1, size // 2 + 1)])
class PolynomialConstraintTest(test_util.TestCase):
"""Tests `PolynomialConstraint` module."""
@parameterized.named_parameters(
dict(testcase_name=name, # pylint: disable=g-complex-comprehension
grid_shape=shape,
stencil_sizes=stencil,
derivative_orders=derivs)
for name, shape, stencil, derivs in TESTS_ALL
)
def test_shapes(self, grid_shape, stencil_sizes, derivative_orders):
"""Tests that PolynomialConstraint returns expected shapes."""
ndims = len(grid_shape)
grid_step = 0.1
steps = (grid_step,) * ndims
stencils = [_make_test_stencil(size, grid_step) for size in stencil_sizes]
method = layers_util.Method.FINITE_VOLUME
module = layers.PolynomialConstraint(
stencils, derivative_orders, method, steps)
inputs = np.random.uniform(size=grid_shape + (module.subspace_size,))
outputs = module(inputs)
actual_shape = outputs.shape
expected_shape = grid_shape + (np.prod(stencil_sizes),)
self.assertEqual(actual_shape, expected_shape)
@parameterized.named_parameters(
dict(testcase_name=name, # pylint: disable=g-complex-comprehension
grid_shape=shape,
stencil_sizes=stencil,
derivative_orders=derivs)
for name, shape, stencil, derivs in TESTS_ALL
)
def test_values(self, grid_shape, stencil_sizes, derivative_orders):
"""Tests that result of PolynomialConstraint satisfies poly-constraints."""
ndims = len(grid_shape)
grid_step = 0.1
steps = (grid_step,) * ndims
stencils = [_make_test_stencil(size, grid_step) for size in stencil_sizes]
method = layers_util.Method.FINITE_VOLUME
module = layers.PolynomialConstraint(
stencils, derivative_orders, method, steps)
inputs = np.random.uniform(size=grid_shape + (module.subspace_size,))
outputs = module(inputs)
a, b = layers_util.polynomial_accuracy_constraints(
stencils, method, derivative_orders, 1, grid_step)
violation = jnp.transpose(jnp.tensordot(a, outputs, axes=[-1, -1])) - b
np.testing.assert_allclose(jnp.max(violation), 0., atol=1e-2)
def _tower_factory(num_output_channels, ndims, conv_block):
rescale_01 = functools.partial(layers.rescale_to_range, min_value=0.,
max_value=1., axes=list(range(ndims)))
components = [rescale_01]
for output_channels, kernel_shape in zip([4], [[3]* ndims]):
components.append(conv_block(output_channels, kernel_shape))
components.append(jax.nn.relu)
components.append(conv_block(num_output_channels, [3] * ndims))
return hk.Sequential(components, name='tower')
class StencilCoefficientsTest(test_util.TestCase):
"""Tests StencilCoefficients module."""
@parameterized.named_parameters([
('1d', (32,), (1,), (4,), layers.PeriodicConv1D, 1),
('2d', (16, 21), (0, 1), (5, 5), layers.PeriodicConv2D, 2),
('3d', (8, 8, 6), (0, 0, 0), (2, 2, 2), layers.PeriodicConv3D, 3),
])
def test_output_shape(self, input_shape, derivative_orders, stencil_sizes,
conv_block, ndims):
grid_step = 0.1
steps = (grid_step,) * ndims
stencils = [_make_test_stencil(size, grid_step) for size in stencil_sizes]
tower_factory = functools.partial(_tower_factory, ndims=ndims,
conv_block=conv_block)
def compute_coefficients(inputs):
net = layers.StencilCoefficients(
stencils, derivative_orders, tower_factory, steps)
return net(inputs)
coefficients_model = hk.without_apply_rng(
hk.transform(compute_coefficients))
rng = hk.PRNGSequence(42)
inputs = np.random.uniform(size=input_shape + (1,))
params = coefficients_model.init(next(rng), inputs)
outputs = coefficients_model.apply(params, inputs)
actual_shape = outputs.shape
expected_shape = input_shape + (np.prod(stencil_sizes),)
self.assertEqual(actual_shape, expected_shape)
class SpatialDerivativeFromLogitsTest(test_util.TestCase):
"""Tests SpatialDerivativeFromLogits module."""
@parameterized.named_parameters(
dict(testcase_name='1D',
derivative_orders=(0,),
input_shape=(256,),
input_offset=(.5,),
target_offset=(1,),
steps=(1,),
stencil_shape=(5,),
tile_layout=(4,)),
dict(testcase_name='2D',
derivative_orders=(0, 1),
input_shape=(64, 64),
input_offset=(0, 0),
target_offset=(10, 0),
steps=(.1, .1),
stencil_shape=(4, 4,),
tile_layout=None),
dict(testcase_name='3D',
derivative_orders=(0, 1, 0),
input_shape=(32, 64, 16),
input_offset=(.5, .5, .5),
target_offset=(0, 0, 0),
steps=(3, 3, 3),
stencil_shape=(3, 3, 3),
tile_layout=(8, 8, 8)),
)
def test_shape(self, derivative_orders, input_shape, input_offset,
target_offset, steps, stencil_shape, tile_layout):
inputs = jnp.ones(input_shape)
for extract_patches_method in ('conv', 'roll'):
with self.subTest(f'method_{extract_patches_method}'):
derivative_from_logits = layers.SpatialDerivativeFromLogits(
stencil_shape, input_offset, target_offset, derivative_orders,
steps, extract_patches_method, tile_layout)
logits = jnp.ones(
input_shape + (derivative_from_logits.subspace_size,))
derivative = derivative_from_logits(jnp.expand_dims(inputs, -1), logits)
expected_shape = input_shape + (1,)
self.assertArrayEqual(expected_shape, derivative.shape)
class SpatialDerivativeTest(test_util.TestCase):
"""Tests SpatialDerivative module."""
@parameterized.named_parameters([
('interpolation', (256,), (0,), np.sin, np.sin, 1e-1),
('first_derivative', (256,), (1,), np.sin, np.cos, 1e-1),
('second_derivative', (256,), (2,), np.sin, lambda x: -np.sin(x), 1e-1),
])
def test_1d(self, grid_shape, derivative, initial_fn, expected_fn, atol):
"""Tests SpatialDerivative module in 1d."""
ndims = len(grid_shape)
grid = grids.Grid(grid_shape, domain=tuple([(0., 2. * np.pi)] * ndims))
stencil_shape = (4,) * ndims
tower_factory = functools.partial(
_tower_factory, ndims=ndims, conv_block=layers.PeriodicConv1D)
for extract_patches_method in ('conv', 'roll'):
with self.subTest(f'method_{extract_patches_method}'):
def module_forward(x):
net = layers.SpatialDerivative(
stencil_shape, grid.cell_center, grid.cell_faces[0], derivative,
tower_factory, grid.step, extract_patches_method) # pylint: disable=cell-var-from-loop
return net(x)
rng = jax.random.PRNGKey(41)
spatial_derivative_model = hk.without_apply_rng(
hk.transform(module_forward))
x, = grid.mesh()
x_target, = grid.mesh(offset=grid.cell_faces[0])
inputs = jnp.expand_dims(initial_fn(x), -1) # add channel dimension
params = spatial_derivative_model.init(rng, inputs)
outputs = spatial_derivative_model.apply(params, inputs)
expected_outputs = np.expand_dims(expected_fn(x_target), -1)
np.testing.assert_allclose(expected_outputs, outputs, atol=atol, rtol=0)
@parameterized.named_parameters([
('interpolation', (128, 128), (0, 0),
lambda x, y: np.sin(2 * x + y), lambda x, y: np.sin(2 * x + y), 0.2),
('first_derivative_x', (128, 128), (1, 0),
lambda x, y: np.cos(2 * x + y), lambda x, y: -2 * np.sin(2 * x + y),
0.1),
])
def test_2d(self, grid_shape, derivative, initial_fn, expected_fn, atol):
"""Tests SpatialDerivative module in 2d."""
ndims = len(grid_shape)
grid = grids.Grid(grid_shape, domain=tuple([(0., 2. * np.pi)] * ndims))
stencil_sizes = (3,) * ndims
tower_factory = functools.partial(
_tower_factory, ndims=ndims, conv_block=layers.PeriodicConv2D)
for extract_patches_method in ('conv', 'roll'):
with self.subTest(f'method_{extract_patches_method}'):
def module_forward(inputs):
net = layers.SpatialDerivative(
stencil_sizes, grid.cell_center, grid.cell_center, derivative,
tower_factory, grid.step, extract_patches_method) # pylint: disable=cell-var-from-loop
return net(inputs)
rng = jax.random.PRNGKey(14)
spatial_derivative_model = hk.without_apply_rng(
hk.transform(module_forward))
x, y = grid.mesh()
inputs = np.expand_dims(initial_fn(x, y), -1) # add channel dimension
params = spatial_derivative_model.init(rng, inputs)
outputs = spatial_derivative_model.apply(params, inputs)
expected_outputs = np.expand_dims(expected_fn(x, y), -1)
np.testing.assert_allclose(
expected_outputs, outputs, atol=atol, rtol=0,
err_msg=f'Failed for method "{extract_patches_method}"')
def test_auxiliary_inputs(self):
"""Tests that auxiliary inputs don't change shape of the output."""
grid = grids.Grid((64,), domain=tuple([(0., 2. * np.pi)]))
stencil_sizes = (3,)
tower_factory = functools.partial(
_tower_factory, ndims=1, conv_block=layers.PeriodicConv1D)
def module_forward(inputs, *auxiliary_inputs):
net = layers.SpatialDerivative(
stencil_sizes, grid.cell_center, grid.cell_center, (1,),
tower_factory, grid.step)
return net(inputs, *auxiliary_inputs)
rng = jax.random.PRNGKey(14)
spatial_derivative_model = hk.without_apply_rng(
hk.transform(module_forward))
inputs = np.expand_dims(grid.mesh()[0], -1) # add channel dimension
auxiliary_inputs = np.ones((64, 1))
params = spatial_derivative_model.init(rng, inputs, auxiliary_inputs)
outputs = spatial_derivative_model.apply(params, inputs, auxiliary_inputs)
self.assertEqual(outputs.shape, (64, 1))
if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
absltest.main()
"""Utility functions for layers.py."""
import enum
import functools
import itertools
import math
from typing import Iterator, Optional, Sequence, Tuple, Union
import jax
from jax import lax
import jax.numpy as jnp
from jax_cfd.ml import tiling
import numpy as np
import scipy.special
Array = Union[np.ndarray, jax.Array]
class Method(enum.Enum):
"""Discretization method."""
FINITE_DIFFERENCE = 1
FINITE_VOLUME = 2
def _kronecker_product(arrays: Sequence[np.ndarray]) -> np.ndarray:
"""Returns a kronecker product of a sequence of arrays."""
return functools.reduce(np.kron, arrays)
def _exponents_up_to_degree(
degree: int,
num_dimensions: int
) -> Iterator[Tuple[int, ...]]:
"""Generate all exponents up to given degree.
Args:
degree: a non-negative integer representing the maximum degree.
num_dimensions: a non-negative integer representing the number of
dimensions.
Yields:
An iterator over all tuples of non-negative integers of length
`num_dimensions`, whose sum is at most `degree`.
Example:
For degree=2 and num_dimensions=2, this iterates through
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0)].
"""
if num_dimensions == 0:
yield tuple()
else:
for d in range(degree + 1):
for exponents in _exponents_up_to_degree(degree - d, num_dimensions - 1):
yield (d,) + exponents
def polynomial_accuracy_constraints(
stencils: Sequence[np.ndarray],
method: Method,
derivative_orders: Sequence[int],
accuracy_order: int,
grid_step: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Setup a linear equation A @ c = b for finite difference coefficients.
Elements are returned in row-major order, e.g., if two stencils of length 2
are provided, then the coefficients for the stencil are organized as
[s_00, s_01, s_10, s_11], where s_ij corresponds to a value at
(stencil_1[i], stencil_2[j]). The returned constraints assume that the
coefficients aim to approximate derivatives of orders `derivative_orders`
at (0, 0).
Example:
For stencils = [np.array([-0.5, 0.5])] * 2, coefficients approximate the
derivative at point (0., 0.) using values located at (-0.5, -0.5),
(-0.5, 0.5), (0.5, -0.5), (0.5, 0.5).
Args:
stencils: list of arrays giving 1D stencils in each direction.
method: discretization method (i.e., finite volumes or finite differences).
derivative_orders: integer derivative orders to approximate in each grid
direction.
accuracy_order: minimum accuracy orders for the solution in each grid
direction.
grid_step: spacing between grid cells.
Returns:
Tuple of arrays `(A, b)` where `A` is 2D and `b` is 1D providing linear
constraints. Any vector of finite difference coefficients `c` such that
`A @ c = b` satisfies the requested accuracy order. The matrix `A` is
guaranteed not to have more rows than columns.
Raises:
ValueError: if the linear constraints are not satisfiable.
References:
https://en.wikipedia.org/wiki/Finite_difference_coefficient
Fornberg, Bengt (1988), "Generation of Finite Difference Formulas on
Arbitrarily Spaced Grids", Mathematics of Computation, 51 (184): 699-706,
doi:10.1090/S0025-5718-1988-0935077-0, ISSN 0025-5718.
"""
if len(stencils) != len(derivative_orders):
raise ValueError('mismatched lengths for stencils and derivative_orders')
if accuracy_order < 1:
raise ValueError('cannot compute constraints with non-positive '
'accuracy_order: {}'.format(accuracy_order))
all_constraints = {}
# See http://g3doc/third_party/py/datadrivenpdes/g3doc/polynomials.md.
num_dimensions = len(stencils)
max_degree = accuracy_order + sum(derivative_orders) - 1
for exponents in _exponents_up_to_degree(max_degree, num_dimensions):
# build linear constraints for a single polynomial term:
# \prod_i {x_i}^{m_i}
lhs_terms = []
rhs_terms = []
exp_stencil_derivative = zip(exponents, stencils, derivative_orders)
for exponent, stencil, derivative_order in exp_stencil_derivative:
if method is Method.FINITE_VOLUME:
if grid_step is None:
raise ValueError('grid_step is required for finite volumes')
# average value of x**m over a centered grid cell
lhs_terms.append(
1 / grid_step * ((stencil + grid_step / 2)**(exponent + 1) -
(stencil - grid_step / 2)**(exponent + 1)) /
(exponent + 1))
elif method is Method.FINITE_DIFFERENCE:
lhs_terms.append(stencil**exponent)
else:
raise ValueError('unexpected method: {}'.format(method))
if exponent == derivative_order:
# we get a factor of m! for m-th order derivative in each direction
rhs_term = scipy.special.factorial(exponent)
else:
rhs_term = 0
rhs_terms.append(rhs_term)
lhs = tuple(_kronecker_product(lhs_terms))
rhs = np.prod(rhs_terms)
if lhs in all_constraints and all_constraints[lhs] != rhs:
raise ValueError('conflicting constraints')
all_constraints[lhs] = rhs
lhs_rows, rhs_rows = zip(*sorted(all_constraints.items()))
A = np.array(lhs_rows) # pylint: disable=invalid-name
b = np.array(rhs_rows)
return A, b
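# Editor's note: a worked sanity check (not part of the original library). For
# a centered 1-D stencil and a first derivative with second-order accuracy,
#   A, b = polynomial_accuracy_constraints(
#       [np.array([-1., 0., 1.])], Method.FINITE_DIFFERENCE,
#       derivative_orders=[1], accuracy_order=2)
# the standard central-difference coefficients c = [-1/2, 0, 1/2] satisfy
# A @ c == b.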
def _high_order_coefficients_1d(
stencil: np.ndarray,
method: Method,
derivative_order: int,
grid_step: Optional[float] = None,
) -> np.ndarray:
"""Calculate highest-order coefficients that appoximate `derivative_order`.
Args:
stencil: 1D array representing locations of the stencil's cells.
method: discretization method (i.e., finite volumes or finite differences).
derivative_order: derivative order being approximated.
grid_step: grid step size.
Returns:
Array representing stencil coefficients that approximate `derivative_order`
spatial derivative on provided `stencil`.
"""
# Use the highest order accuracy we can ensure in general. (In some cases,
# e.g., centered finite differences, this solution actually has higher order
# accuracy.)
accuracy_order = stencil.size - derivative_order
A, b = polynomial_accuracy_constraints( # pylint: disable=invalid-name
[stencil], method, [derivative_order], accuracy_order, grid_step)
return np.linalg.solve(A, b)
def polynomial_accuracy_coefficients(
stencils: Sequence[np.ndarray],
method: Method,
derivative_orders: Sequence[int],
accuracy_order: Optional[int] = None,
grid_step: Optional[float] = None,
) -> np.ndarray:
"""Calculate standard finite volume coefficients.
These coefficients are constructed by taking an outer product of coefficients
along each dimension independently. The resulting coefficients have *at least*
the requested accuracy order. The derivative is approximated at `0.` position
along each stencil.
Args:
stencils: sequence of 1d stencils, one per grid dimension.
method: discretization method (i.e., finite volumes or finite differences).
derivative_orders: integer derivative orders to approximate, per grid
dimension.
accuracy_order: accuracy order for the solution. By default, the highest
possible accuracy is used in each direction.
grid_step: spacing between grid cells. Required if calculating a finite
volume stencil.
Returns:
NumPy array with one dimension per stencil giving finite difference (or
finite volume) coefficients on the grid.
"""
slices = []
sizes = []
all_coefficients = []
for stencil, derivative_order in zip(stencils, derivative_orders):
if accuracy_order is None:
excess = 0
else:
excess = stencil.size - derivative_order - accuracy_order
start = excess // 2
stop = stencil.size - excess // 2
slice_ = slice(start, stop)
axis_coefficients = _high_order_coefficients_1d(
stencil[slice_], method, derivative_order, grid_step)
slices.append(slice_)
sizes.append(stencil[slice_].size)
all_coefficients.append(axis_coefficients)
result = np.zeros(tuple(stencil.size for stencil in stencils))
result[tuple(slices)] = _kronecker_product(all_coefficients).reshape(sizes)
return result
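# Editor's note: a small example (not part of the original library):
#   polynomial_accuracy_coefficients(
#       [np.array([-1., 0., 1.])], Method.FINITE_DIFFERENCE, [1])
# returns the classic central-difference weights [-1/2, 0, 1/2], matching the
# cases exercised in PolynomialAccuracyCoefficientsTests below.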
def get_roll_and_shift(
input_offset: Tuple[float, ...],
target_offset: Tuple[float, ...]
) -> Tuple[Tuple[int, ...], Tuple[float, ...]]:
"""Decomposes delta as integer `roll` and positive fractional `shift`."""
delta = [t - i for t, i in zip(target_offset, input_offset)]
roll = tuple(-math.floor(d) for d in delta)
shift = tuple(d + r for d, r in zip(delta, roll))
return roll, shift
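# Editor's note: worked examples (not part of the original library):
#   get_roll_and_shift((0.5,), (1.0,)) == ((0,), (0.5,))
#   get_roll_and_shift((1.0,), (0.5,)) == ((1,), (0.5,))
# i.e. the integer `roll` removes whole-cell displacements and `shift` is the
# remaining fractional offset in [0, 1).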
def get_stencils(
stencil_sizes: Tuple[int, ...],
offset: Tuple[float, ...],
steps: Tuple[float, ...]
) -> Tuple[np.ndarray, ...]:
"""Computes stencils locations.
Generates stencils such that the target offset is placed at (0.,) * ndims,
which is needed to obtain correct polynomial constraints. This places an
equal number of cells on the left and right when the shift is half a cell
(for even-sized stencils) or zero (for odd-sized stencils); otherwise an
extra cell is added on one of the sides. The order of the returned stencils
is row-major, i.e. an outer product of shifts along axes.
Args:
stencil_sizes: sizes of 1d stencils along each directions.
offset: the target offset relative to the input offset, i.e.
`target_offset - input_offset`.
steps: distances between adjacent grid cells.
Returns:
stencils: list of 1d stencils representing locations of the centers of cells
of inputs array.
"""
stencils = []
for size, o, step in zip(stencil_sizes, offset, steps):
left = -((size - 1) // 2)
shifts = range(left, left + size)
stencils.append(np.array([(-o + s) * step for s in shifts]))
return tuple(stencils)
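# Editor's note: a worked example (not part of the original library): for an
# even-sized stencil straddling a half-cell shift,
#   get_stencils((4,), offset=(0.5,), steps=(1.,))
# returns (np.array([-1.5, -0.5, 0.5, 1.5]),), i.e. two cells on each side of
# the target location at 0.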
def _get_padding(
kernel_shape: Tuple[int, ...]
) -> Tuple[Tuple[int, int], ...]:
"""Returns the padding for convolutions used in `extract_patches`.
Note that the padding here is "flipped" compared to the padding used in
`PeriodicConvGeneral`.
Args:
kernel_shape: the shape of the convolutional kernel.
Returns:
A tuple of pairs of ints. Each pair indicates the padding that should be
added before and after the array for a periodic convolution with shape
`kernel_shape`.
"""
# TODO(jamieas): use this function to compute padding for
# `PeriodicConvGeneral`.
padding = []
for kernel_size in kernel_shape[:-2]:
pad_right = kernel_size // 2
pad_left = kernel_size - pad_right - 1
padding.append((pad_left, pad_right))
return tuple(padding)
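# Editor's note: a worked example (not part of the original library): for a
# 2-D kernel of shape (3, 4, in_channels, out_channels),
#   _get_padding((3, 4, 1, 1)) == ((1, 1), (1, 2))
# i.e. the (before, after) padding per spatial axis used for a periodic
# convolution with this kernel.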
_DIMENSION_NUMBERS = {
1: ('NWC', 'WIO', 'NWC'),
2: ('NHWC', 'HWIO', 'NHWC'),
3: ('NHWDC', 'HWDIO', 'NHWDC'),
}
PrecisionLike = Optional[Union[lax.Precision, Tuple[lax.Precision,
lax.Precision]]]
def periodic_convolution(
x: Array,
kernel: Array,
tile_layout: Optional[Tuple[int, ...]] = None,
precision: PrecisionLike = lax.Precision.HIGHEST,
) -> Array:
"""Applies a periodic convolution."""
num_spatial_dims = kernel.ndim - 2
padding = _get_padding(kernel.shape)
strides = [1] * num_spatial_dims
dimension_numbers = _DIMENSION_NUMBERS[num_spatial_dims]
conv = functools.partial(jax.lax.conv_general_dilated,
rhs=kernel,
window_strides=strides,
padding='VALID',
dimension_numbers=dimension_numbers,
precision=precision)
return tiling.apply_convolution(conv, x, layout=tile_layout, padding=padding)
# Caching the result of _patch_kernel() ensures that only one constant value is
# used as a side-input into JAX's jit/pmap. This ensures that XLA's Common
# Subexpression Elimination (CSE) pass can consolidate calls to extract patches
# on the same array.
@functools.lru_cache()
def _patch_kernel( # pytype: disable=annotation-type-mismatch # numpy-scalars
patch_shape: Tuple[int, ...],
dtype: np.dtype = np.float32
) -> np.ndarray:
"""Returns a convolutional kernel that extracts patches."""
patch_size = np.prod(patch_shape)
kernel_2d = np.eye(patch_size, dtype=dtype)
kernel_shape = (patch_size, 1) + patch_shape
kernel_nd = kernel_2d.reshape(kernel_shape)
return np.moveaxis(kernel_nd, (0, 1), (-1, -2))
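# Editor's note (descriptive comment, not original code): for
# patch_shape=(3, 3) the returned kernel has shape (3, 3, 1, 9); convolving a
# single-channel input with it yields one output channel per element of the
# patch, which is what `_extract_patches_conv` relies on.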
@functools.partial(jax.jit, static_argnums=(1,))
def _extract_patches_roll(
x: Array,
patch_shape: Tuple[int, ...]
) -> Array:
"""Extract patches of the given shape using a vmapped `roll` operation."""
# Computes shifts required for the given `patch_shape`.
x = jnp.squeeze(x, -1)
shifts = []
for size in patch_shape:
shifts.append(range(-size // 2 + 1, size // 2 + 1))
rolls = -np.stack(tuple(itertools.product(*shifts)))
out_axis = x.ndim
roll_axes = range(out_axis)
in_axes = (None, 0, None)
return jax.vmap(jnp.roll, in_axes, out_axis)(x, rolls, roll_axes)
@functools.partial(jax.jit, static_argnums=(1, 2))
def _extract_patches_conv(
x: Array,
patch_shape: Tuple[int, ...],
tile_layout: Optional[Tuple[int, ...]],
) -> Array:
"""Extract patches of the given shape using a tiled convolution."""
kernel = _patch_kernel(patch_shape, dtype=x.dtype)
# the kernel can be represented exactly in bfloat16
precision = (lax.Precision.HIGHEST, lax.Precision.DEFAULT)
return periodic_convolution(x, kernel, tile_layout, precision=precision)
def extract_patches(
x: Array,
patch_shape: Tuple[int, ...],
method: str = 'roll',
tile_layout: Optional[Tuple[int, ...]] = None):
"""Extracts patches of given shape, stacks them along the channel dimension.
For example,
```
x = [[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
x = np.expand_dims(x, -1) # Add 'channel' dimension.
y = extract_patches(x, [3, 3])
y[0, 0] # [15, 12, 13, 3, 0, 1, 7, 4, 5]
y[1, 1] # [0, 1, 2, 4, 5, 6, 7, 8, 9]
z = extract_patches(x, [2, 2])
z[0, 0] # [0, 1, 4, 5]
z[2, 2] # [10, 11, 14, 15]
```
In particular, for even patch sizes, `extract_patches` includes extra values
with _higher_ indices.
Args:
x: the array with shape [d0, ..., dk] from which we will extract patches.
patch_shape: a tuple (p0, ..., pk) describing the shape of the patches to
extract.
method: determines which method is used for extracting patches. Must be
either 'roll' or 'conv'.
tile_layout: an optional tuple (t0, ..., tk) describing the tiling that will
be used to perform the convolutions that extract patches. If `None`, then
no tiling is performed. If `method == 'roll'`, this argument has not
effect.
Returns:
An array of shape [d0, ..., dk, c] where `c = prod(patch_shape)`.
"""
# TODO(jamieas): consider removing the 'roll' method once the convolutional
# one has been optimized.
if method == 'roll':
return _extract_patches_roll(x, patch_shape)
elif method == 'conv':
return _extract_patches_conv(x, patch_shape, tile_layout)
else:
raise ValueError(f'Unknown `method` passed to `extract_patches`: {method}.')
def fused_extract_patches(
x: Array,
patch_shapes: Sequence[Tuple[int, ...]],
tile_layout: Optional[Tuple[int, ...]] = None,
):
kernel = np.concatenate(
[_patch_kernel(s, dtype=x.dtype) for s in patch_shapes], axis=-1)
# the kernel can be represented exactly in bfloat16
precision = (lax.Precision.HIGHEST, lax.Precision.DEFAULT)
return periodic_convolution(x, kernel, tile_layout, precision=precision)
# TODO(dkochkov) consider alternative ops for better efficiency on MXU.
def apply_coefficients(coefficients, stencil_values):
"""Constructs array as a weighted sum of `stencil_values`."""
return jnp.sum(coefficients * stencil_values, axis=-1, keepdims=True)
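# Editor's note: a minimal example (not part of the original library):
#   coefficients = jnp.asarray([[-0.5, 0., 0.5]])   # shape (1, 3)
#   stencil_values = jnp.asarray([[1., 2., 3.]])    # shape (1, 3)
#   apply_coefficients(coefficients, stencil_values)  # == [[1.]]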
"""Tests for google3.research.simulation.whirl.layers_util."""
import itertools
from absl.testing import absltest
from absl.testing import parameterized
from jax_cfd.base import test_util
from jax_cfd.ml import layers_util
import numpy as np
FINITE_DIFF = layers_util.Method.FINITE_DIFFERENCE
FINITE_VOL = layers_util.Method.FINITE_VOLUME
def _stencil_id(stencil_coordinates, stencil_sizes):
"""Computes id of in the stencil basis given stencil_coordinates."""
axes_shifts = [1]
for stencil_size in stencil_sizes:
axes_shifts.append(axes_shifts[-1] * stencil_size)
axes_shifts = np.array(axes_shifts[:-1])
return np.sum(np.array(stencil_coordinates[::-1]) * axes_shifts)
class HelperFunctionsTest(test_util.TestCase):
"""Tests helper functions in layers_util."""
def test_exponents_up_to_degree(self):
exponents_iterator = layers_util._exponents_up_to_degree(2, 2)
expected_values = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0)]
actual_values = list(exponents_iterator)
self.assertEqual(expected_values, actual_values)
class PolynomialAccuracyConstraintsTest(test_util.TestCase):
"""Tests polynomial accuracy constraints utils."""
@parameterized.parameters(
dict(
accuracy_order=1,
method=FINITE_DIFF,
expected_a=[[1, 1]],
expected_b=[1]),
dict(
accuracy_order=1,
method=FINITE_VOL,
expected_a=[[1, 1]],
expected_b=[1]),
dict(
accuracy_order=2,
method=FINITE_DIFF,
expected_a=[[-1 / 2, 1 / 2], [1, 1]],
expected_b=[0, 1]),
dict(
accuracy_order=2,
method=FINITE_VOL,
expected_a=[[-1 / 2, 1 / 2], [1, 1]],
expected_b=[0, 1]),
)
def test_constraints_1d(self, accuracy_order, method, expected_a, expected_b):
a, b = layers_util.polynomial_accuracy_constraints(
[np.array([-0.5, 0.5])], method, derivative_orders=[0],
accuracy_order=accuracy_order, grid_step=1.0)
np.testing.assert_allclose(a, expected_a)
np.testing.assert_allclose(b, expected_b)
def test_constraints_2d_second_order_zeroth_derivative(self):
# these constraints should be under-determined.
stencils = [np.array([-0.5, 0.5])] * 2
a, b = layers_util.polynomial_accuracy_constraints(
stencils,
FINITE_DIFF,
derivative_orders=[0, 0],
accuracy_order=2)
# three constraints, for each term in [1, x, y]
self.assertEqual(a.shape, (3, 4))
self.assertEqual(b.shape, (3,))
# explicitly test two valid solutions
np.testing.assert_allclose(a.dot([1 / 4, 1 / 4, 1 / 4, 1 / 4]), b)
np.testing.assert_allclose(a.dot([4 / 10, 1 / 10, 1 / 10, 4 / 10]), b)
def test_constraints_2d_first_order_first_derivative(self):
stencils = [np.array([-1, 0, 1])] * 2
a, b = layers_util.polynomial_accuracy_constraints(
stencils,
FINITE_DIFF,
derivative_orders=[1, 0],
accuracy_order=1)
# three constraints, for each term in [1, x, y]
self.assertEqual(a.shape, (3, 9))
self.assertEqual(b.shape, (3,))
# explicitly test a valid solution
solution = np.array([-1, 0, -1, 0, 0, 0, 1, 0, 1]) / 4
np.testing.assert_allclose(a.dot(solution), b)
# explicitly test an invalid solution.
# this solution is invalid because the stencil is a linear combination
# of derivatives in both the x and y directions.
non_solution = np.array([-1, 0, -1, 1, 0, -1, 1, 0, 1]) / 4
self.assertGreater(np.linalg.norm(a.dot(non_solution) - b), 0.1)
def test_constraints_2d_first_order_second_derivative(self):
stencils = [np.array([-1, 0, 1])] * 2
a, b = layers_util.polynomial_accuracy_constraints(
stencils,
FINITE_DIFF,
derivative_orders=[1, 1],
accuracy_order=1)
# six constraints, for each term in [1, x, y, x^2, xy, y^2]
self.assertEqual(a.shape, (6, 9))
self.assertEqual(b.shape, (6,))
# explicitly test a valid solution
solution = np.array([1, 0, -1, 0, 0, 0, -1, 0, 1]) / 4
np.testing.assert_allclose(a.dot(solution), b)
def test_constraints_3d_second_order_zeroth_derivative(self):
stencils = [np.array([-1, 0, 1])] * 3
a, b = layers_util.polynomial_accuracy_constraints(
stencils,
FINITE_DIFF,
derivative_orders=[0, 0, 0],
accuracy_order=2)
# four constraints, for each term in [1, x, y, z]
self.assertEqual(a.shape, (4, 27))
self.assertEqual(b.shape, (4,))
# explicitly test a valid solution
stencil = list(itertools.product(*stencils))
solution_a = np.zeros(27)
solution_a[stencil.index((0, 0, 0))] = 1.
solution_b = np.zeros(27)
solution_b[stencil.index((0, -1, 0))] = 0.25
solution_b[stencil.index((0, 0, 0))] = 0.5
solution_b[stencil.index((0, 1, 0))] = 0.25
np.testing.assert_allclose(a.dot(solution_a), b)
np.testing.assert_allclose(a.dot(solution_b), b)
def test_constraints_3d_second_order_first_derivative(self):
stencils = [np.array([-1, 0, 1])] * 3
a, b = layers_util.polynomial_accuracy_constraints(
stencils,
FINITE_DIFF,
derivative_orders=[0, 0, 1],
accuracy_order=2)
# ten constraints: one constant, three linear, and six quadratic terms
self.assertEqual(a.shape, (10, 27))
self.assertEqual(b.shape, (10,))
# explicitly test a few valid solutions
solution = np.zeros(27)
solution[_stencil_id([1, 1, 0], [3] * 3)] = -0.5
solution[_stencil_id([1, 1, 2], [3] * 3)] = 0.5
np.testing.assert_allclose(a.dot(solution), b)
class PolynomialAccuracyCoefficientsTests(test_util.TestCase):
# For test-cases, see
# https://en.wikipedia.org/wiki/Finite_difference_coefficient
@parameterized.parameters(
dict(stencil=[-1, 0, 1], derivative_order=1, expected=[-1 / 2, 0, 1 / 2]),
dict(stencil=[-1, 0, 1], derivative_order=2, expected=[1, -2, 1]),
dict(
stencil=[-2, -1, 0, 1, 2],
derivative_order=2,
expected=[-1 / 12, 4 / 3, -5 / 2, 4 / 3, -1 / 12]),
dict(
stencil=[-2, -1, 0, 1, 2],
derivative_order=2,
accuracy_order=1,
expected=[0, 1, -2, 1, 0]),
dict(stencil=[0, 1], derivative_order=1, expected=[-1, 1]),
dict(stencil=[0, 2], derivative_order=1, expected=[-0.5, 0.5]),
dict(stencil=[0, 0.5], derivative_order=1, expected=[-2, 2]),
dict(
stencil=[0, 1, 2, 3, 4],
derivative_order=4,
expected=[1, -4, 6, -4, 1]),
)
def test_finite_difference_coefficients_1d(
self,
stencil,
derivative_order,
expected,
accuracy_order=None
):
result = layers_util.polynomial_accuracy_coefficients(
[np.array(stencil)], FINITE_DIFF, [derivative_order], accuracy_order)
np.testing.assert_allclose(result, expected)
@parameterized.parameters(
dict(
stencils=[[-0.5, 0.5], [-0.5, 0.5]],
derivative_orders=[0, 0],
expected=[[0.25, 0.25], [0.25, 0.25]]),
dict(
stencils=[[-0.5, 0.5], [-0.5, 0.5]],
derivative_orders=[0, 1],
expected=[[-0.5, 0.5], [-0.5, 0.5]]),
dict(
stencils=[[-0.5, 0.5], [-0.5, 0.5]],
derivative_orders=[1, 1],
expected=[[1, -1], [-1, 1]]),
dict(
stencils=[[-1, 0, 1], [-0.5, 0.5]],
derivative_orders=[1, 0],
expected=[[-0.25, -0.25], [0, 0], [0.25, 0.25]]),
)
def test_finite_difference_coefficients_2d(self, stencils, derivative_orders,
expected):
args = ([np.array(s) for s in stencils], FINITE_DIFF, derivative_orders)
result = layers_util.polynomial_accuracy_coefficients(*args)
np.testing.assert_allclose(result, expected)
result = layers_util.polynomial_accuracy_coefficients(
*args, accuracy_order=1)
np.testing.assert_allclose(result, expected)
def test_finite_difference_coefficients_3d(self):
stencils = [[-0.5, 0.5] for _ in range(3)]
derivative_orders = [0, 0, 0]
expected = [
[[0.125, 0.125], [0.125, 0.125]], [[0.125, 0.125], [0.125, 0.125]]
]
args = ([np.array(s) for s in stencils], FINITE_DIFF, derivative_orders)
result = layers_util.polynomial_accuracy_coefficients(*args)
np.testing.assert_allclose(result, expected)
result = layers_util.polynomial_accuracy_coefficients(
*args, accuracy_order=1)
np.testing.assert_allclose(result, expected)
@parameterized.parameters(
dict(stencil=[-0.5, 0.5], derivative_order=0, expected=[1 / 2, 1 / 2]),
dict(
stencil=[-1.5, -0.5, 0.5, 1.5],
derivative_order=0,
accuracy_order=1,
expected=[0, 1 / 2, 1 / 2, 0]),
dict(stencil=[-1, 1], derivative_order=0, expected=[1 / 2, 1 / 2]),
dict(stencil=[-1.5, -0.5], derivative_order=0, expected=[-1 / 2, 3 / 2]),
dict(stencil=[-0.5, 0.5, 1.5],
derivative_order=0,
expected=[1 / 3, 5 / 6, -1 / 6]),
dict(stencil=[-0.5, 0.5], derivative_order=1, expected=[-1, 1]),
dict(stencil=[-1, 1], derivative_order=1, expected=[-1 / 2, 1 / 2]),
dict(stencil=[-1, 0, 1], derivative_order=1, expected=[-1 / 2, 0, 1 / 2]),
dict(stencil=[0.5, 1.5, 2.5], derivative_order=1, expected=[-2, 3, -1]),
dict(
stencil=[-1.5, -0.5, 0.5, 1.5],
derivative_order=1,
expected=[1 / 12, -5 / 4, 5 / 4, -1 / 12]),
dict(
stencil=[-.75, -0.25, 0.25, 0.75],
derivative_order=1,
expected=[1 / 6, -5 / 2, 5 / 2, -1 / 6]),
)
def test_finite_volume_coefficients_1d(self,
stencil,
derivative_order,
expected,
accuracy_order=None):
step = stencil[1] - stencil[0]
result = layers_util.polynomial_accuracy_coefficients(
[np.array(stencil)], FINITE_VOL, [derivative_order], accuracy_order,
grid_step=step)
np.testing.assert_allclose(result, expected)
class ExtractPatchesTest(test_util.TestCase):
@parameterized.named_parameters(
dict(testcase_name='_1D_even',
x=np.arange(10, dtype=np.float32).reshape([10, 1]),
patch_shape=(4,),
indices_and_patches=(
((0,), np.array([9., 0., 1., 2.])),
((5,), np.array([4., 5., 6., 7.])))),
dict(testcase_name='_1D_odd',
x=np.arange(10, dtype=np.float32).reshape([10, 1]),
patch_shape=(5,),
indices_and_patches=(
((0,), np.array([8., 9., 0., 1., 2.])),
((5,), np.array([3., 4., 5., 6., 7.])))),
dict(testcase_name='_2D_even',
x=np.arange(16, dtype=np.float32).reshape([4, 4, 1]),
patch_shape=(2, 2),
indices_and_patches=(
((0, 0), np.array([0., 1., 4., 5.])),
((1, 1), np.array([5, 6, 9, 10])))),
dict(testcase_name='_2D_odd',
x=np.arange(16, dtype=np.float32).reshape([4, 4, 1]),
patch_shape=(3, 3),
indices_and_patches=(
((0, 0), np.array([15., 12., 13., 3., 0., 1., 7., 4., 5.])),
((1, 1), np.array([0., 1., 2., 4., 5., 6., 8., 9., 10.])))),
dict(testcase_name='_3D_even',
x=np.arange(125, dtype=np.float32).reshape([5, 5, 5, 1]),
patch_shape=(2, 2, 2),
indices_and_patches=(
((0, 0, 0),
np.array([0., 1., 5., 6., 25., 26., 30., 31.])),
((3, 3, 3),
np.array([93., 94., 98., 99., 118., 119., 123., 124.])))),
dict(testcase_name='_3D_odd',
x=np.arange(125, dtype=np.float32).reshape([5, 5, 5, 1]),
patch_shape=(3, 3, 3),
indices_and_patches=(
((0, 0, 0),
np.array(
[124., 120., 121., 104., 100., 101., 109., 105., 106.,
24., 20., 21., 4., 0., 1., 9., 5., 6.,
49., 45., 46., 29., 25., 26., 34., 30., 31.])),
((3, 3, 3),
np.array(
[62., 63., 64., 67., 68., 69., 72., 73., 74.,
87., 88., 89., 92., 93., 94., 97., 98., 99.,
112., 113., 114., 117., 118., 119., 122., 123., 124.])))),
)
def test_extract_patches(self, x, patch_shape, indices_and_patches):
"""Tests `layers_util.extract_patches`."""
for method in ('roll', 'conv'):
with self.subTest(f'method_{method}'):
patches = layers_util.extract_patches(x, patch_shape, method=method)
for idx, expected_patch in indices_and_patches:
actual_patch = patches[idx]
np.testing.assert_allclose(actual_patch, expected_patch)
if __name__ == '__main__':
absltest.main()
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metrics for whirl experiments."""
import functools
from typing import Any, Callable, Optional, Tuple
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils as arr_utils
from jax_cfd.base import grids
Array = grids.Array
PyTree = Any
# TODO(shoyer): rewrite these metric functions in terms of vmap rather than
# explicit batch and time dimensions. Also consider go/clu-metrics.
def l2_loss_cumulative(trajectory: Tuple[Array, ...],
target: Tuple[Array, ...],
n: Optional[int] = None,
scale: float = 1,
time_axis=0) -> float:
"""Computes cumulative L2 loss on first `n` time slices."""
trajectory = arr_utils.slice_along_axis(trajectory, time_axis, slice(None, n))
target = arr_utils.slice_along_axis(target, time_axis, slice(None, n))
return sum((scale * jnp.square(x - y)).sum()
for x, y in zip(trajectory, target))
def l2_loss_single_step(trajectory: Tuple[Array, ...],
target: Tuple[Array, ...],
n: int,
scale: float = 1,
time_axis=0) -> float:
"""Computes L2 loss on `n`th time slice."""
trajectory = arr_utils.slice_along_axis(trajectory, time_axis, n)
target = arr_utils.slice_along_axis(target, time_axis, n)
return sum((scale * jnp.square(x - y)).sum()
for x, y in zip(trajectory, target))
def _normalize(array: Array, state_axes: Tuple[int, ...]) -> Array:
l2_norm = (array ** 2).sum(axis=state_axes, keepdims=True) ** 0.5
return array / l2_norm
def correlation_single_step(trajectory: Tuple[Array, ...],
target: Tuple[Array, ...],
n: int,
time_axis=0,
batch_axis=1) -> float:
"""Computes correlation on the `n`th time slice."""
trajectory = jnp.stack(arr_utils.slice_along_axis(trajectory, time_axis, n))
target = jnp.stack(arr_utils.slice_along_axis(target, time_axis, n))
state_axes = tuple(axis if axis <= time_axis else axis - 1
for axis in range(trajectory.ndim)
if axis != time_axis and axis != batch_axis)
trajectory_normalized = _normalize(trajectory, state_axes)
target_normalized = _normalize(target, state_axes)
return (trajectory_normalized * target_normalized).sum(axis=state_axes).mean()
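# Illustrative usage sketch (not part of the original module): the metrics
# above expect `trajectory` and `target` to be tuples of arrays with time as
# the leading axis (and, for `correlation_single_step`, a batch axis in the
# second position). `predicted_fields` and `target_fields` are hypothetical
# placeholders.
#
#   loss = l2_loss_cumulative(predicted_fields, target_fields, n=10)
#   corr = correlation_single_step(predicted_fields, target_fields, n=10)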
def local_reduce(metric: Callable[..., Array],
reduction_function: Callable[[Array, int], Array],
batch_axis: int = 0) -> Callable[..., Array]:
"""Computes the mean of a metric over a local batch.
Example usage:
```
  average_l2_loss = metrics.local_reduce(
      functools.partial(metrics.l2_loss_cumulative, n=10), jnp.mean)
```
Args:
metric: a callable that returns a single array.
reduction_function: a callable that takes arguments `x, axis` and returns
a single array. For example, `jnp.mean`.
batch_axis: an integer indicating the batch dimension. Defaults to 0.
Returns:
    A function that takes the same arguments as `metric` but maps it over the
    batch axis and reduces the result with `reduction_function`.
"""
def reduced_metric(*args, **kwargs):
metric_value = jax.vmap(metric, batch_axis)(*args, **kwargs)
return reduction_function(metric_value, batch_axis)
return reduced_metric
def distributed_reduce(metric: Callable[..., Array],
reduction_function: Callable[[Array, str], Array],
axis_name: str = 'batch') -> Callable[..., Array]:
"""Computes the mean of a metric over a distributed batch.
Note that the functions returned are only suitable for use inside another
function that is pmapped along the same axis.
Example usage:
```
average_l2_loss = metrics.distributed_reduce(
functools.partial(metrics.l2_loss_cumulative, n=10),
reduction_function=jax.lax.pmean)
@functools.partial(jax.pmap, axis_name='batch')
def distributed_train_step(...):
prediction = ...
loss = average_l2_loss(prediction, target)
```
Args:
metric: a callable that takes locally batched arrays and returns a single
array.
reduction_function: a callable that takes arguments `x, axis_name` and
returns a single array. For example, `jax.lax.pmean`.
    axis_name: the name of the axis that will be used for distributing metric
      computation and combining results.
Returns:
    A function that takes the same arguments as `metric` but additionally
    reduces the result across the axis `axis_name` using `reduction_function`.
"""
def reduced_metric(*args, **kwargs):
metric_value = metric(*args, **kwargs)
return reduction_function(metric_value, axis_name)
return reduced_metric
local_mean = functools.partial(local_reduce, reduction_function=jnp.mean)
local_sum = functools.partial(local_reduce, reduction_function=jnp.sum)
distributed_mean = functools.partial(
distributed_reduce, reduction_function=jax.lax.pmean)
distributed_sum = functools.partial(
distributed_reduce, reduction_function=jax.lax.psum)
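# Illustrative usage sketch (not part of the original module): `local_mean`
# vmaps a per-example metric over a leading batch axis and averages the
# result. `predictions` and `targets` are hypothetical placeholders, each a
# tuple of arrays with batch as the leading axis and time as the next axis.
#
#   batched_loss = local_mean(functools.partial(l2_loss_single_step, n=5))
#   mean_loss = batched_loss(predictions, targets)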
"""Defines AbstractModel API, standard implementations and helper functions."""
import functools
from typing import Callable, Optional
import gin
import haiku as hk
from jax_cfd.base import grids
# Note: decoders, encoders and equations contain standard gin-configurables;
# they are imported here (though unused directly) so that they get registered.
from jax_cfd.ml import decoders # pylint: disable=unused-import
from jax_cfd.ml import encoders # pylint: disable=unused-import
from jax_cfd.ml import equations # pylint: disable=unused-import
from jax_cfd.ml import physics_specifications
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
def _identity(x):
return x
class DynamicalSystem(hk.Module):
"""Abstract class for modeling dynamical systems."""
def __init__(
self,
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
name: Optional[str] = None
):
"""Constructs an instance of a class."""
super().__init__(name=name)
self.grid = grid
self.dt = dt
self.physics_specs = physics_specs
def encode(self, x):
"""Encodes input trajectory `x` to the model state."""
raise NotImplementedError("Model subclass did not define encode")
def decode(self, x):
"""Decodes a model state `x` to a data representation."""
raise NotImplementedError("Model subclass did not define decode")
def advance(self, x):
"""Returns a model state `x` advanced in time by `self.dt`."""
raise NotImplementedError("Model subclass did not define advance")
def trajectory(
self,
x,
outer_steps: int,
inner_steps: int = 1,
*,
start_with_input: bool = False,
post_process_fn: Callable = _identity,
):
"""Returns a final model state and trajectory."""
return trajectory_from_step(
self.advance, outer_steps, inner_steps,
start_with_input=start_with_input,
post_process_fn=post_process_fn
)(x)
@gin.register
class ModularStepModel(DynamicalSystem):
"""Dynamical model based on independent encoder/decoder/step components."""
def __init__(
self,
grid: grids.Grid,
dt: float,
physics_specs: physics_specifications.BasePhysicsSpecs,
advance_module=gin.REQUIRED,
encoder_module=gin.REQUIRED,
decoder_module=gin.REQUIRED,
name: Optional[str] = None
):
"""Constructs an instance of a class."""
super().__init__(grid=grid, dt=dt, physics_specs=physics_specs, name=name)
self.advance_module = advance_module(grid, dt, physics_specs)
self.encoder_module = encoder_module(grid, dt, physics_specs)
self.decoder_module = decoder_module(grid, dt, physics_specs)
def encode(self, x):
return self.encoder_module(x)
def decode(self, x):
return self.decoder_module(x)
def advance(self, x):
return self.advance_module(x)
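# Illustrative gin-binding sketch (not part of the original module): one way
# `ModularStepModel` might be wired together in a gin config file. The
# configurables on the right-hand side are hypothetical placeholders for
# modules registered elsewhere in jax_cfd.ml.
#
#   ModularStepModel.encoder_module = @encoders.aligned_array_encoder
#   ModularStepModel.decoder_module = @decoders.aligned_array_decoder
#   ModularStepModel.advance_module = @equations.modular_navier_stokes_model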
@gin.configurable
def get_model_cls(grid, dt, physics_specs, model_cls=gin.REQUIRED):
"""Returns a configured model class."""
return functools.partial(model_cls, grid, dt, physics_specs)
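# Illustrative usage sketch (not part of the original module): a configured
# model is typically instantiated and unrolled inside a Haiku transform,
# assuming `model_cls` has been bound via gin. `grid`, `dt`, `physics_specs`
# and `inputs` are hypothetical placeholders.
#
#   def forward(inputs):
#     model = get_model_cls(grid, dt, physics_specs)()
#     return model.trajectory(model.encode(inputs), outer_steps=10)
#   forward_fn = hk.transform(forward)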
def repeated(fn: Callable, steps: int) -> Callable:
"""Returns a repeatedly applied version of fn()."""
def f_repeated(x_initial):
g = lambda x, _: (fn(x), None)
x_final, _ = hk.scan(g, x_initial, xs=None, length=steps)
return x_final
return f_repeated
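# For example, `repeated(step_fn, 3)(x)` computes the same state as
# `step_fn(step_fn(step_fn(x)))`, expressed as a single `hk.scan` loop (and
# therefore intended to run inside an `hk.transform`).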
@gin.configurable(allowlist=("set_checkpoint",))
def trajectory_from_step(
step_fn: Callable,
outer_steps: int,
inner_steps: int,
*,
start_with_input: bool,
post_process_fn: Callable,
set_checkpoint: bool = False,
):
"""Returns a function that accumulates repeated applications of `step_fn`.
Compute a trajectory by repeatedly calling `step_fn()`
`outer_steps * inner_steps` times.
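  Example usage (an illustrative sketch; `advance_fn` and the initial state
  `x0` are hypothetical placeholders, and the returned function uses `hk.scan`
  so it must run inside an `hk.transform`):
  ```
  unroll_fn = trajectory_from_step(
      advance_fn, outer_steps=10, inner_steps=4,
      start_with_input=False, post_process_fn=lambda state: state)
  final_state, trajectory = unroll_fn(x0)
  ```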
Args:
step_fn: function that takes a state and returns state after one time step.
outer_steps: number of steps to save in the generated trajectory.
inner_steps: number of repeated calls to step_fn() between saved steps.
start_with_input: if True, output the trajectory at steps [0, ..., steps-1]
instead of steps [1, ..., steps].
post_process_fn: function to apply to trajectory outputs.
    set_checkpoint: whether to wrap `step_fn` in `hk.remat` (`jax.checkpoint`).
Returns:
A function that takes an initial state and returns a tuple consisting of:
(1) the final frame of the trajectory.
(2) trajectory of length `outer_steps` representing time evolution.
"""
if set_checkpoint:
step_fn = hk.remat(step_fn)
if inner_steps != 1:
step_fn = repeated(step_fn, inner_steps)
def step(carry_in, _):
carry_out = step_fn(carry_in)
frame = carry_in if start_with_input else carry_out
return carry_out, post_process_fn(frame)
def multistep(x):
return hk.scan(step, x, xs=None, length=outer_steps)
return multistep
"""Helper methods for constructing trajectory functions in model_builder.py."""
import functools
from jax_cfd.base import array_utils
def with_preprocessing(fn, preprocess_fn):
"""Generates a function that computes `fn` on `preprocess_fn(x)`."""
@functools.wraps(fn)
def apply_fn(x, *args, **kwargs):
return fn(preprocess_fn(x), *args, **kwargs)
return apply_fn
def with_post_processing(fn, post_process_fn):
"""Generates a function that applies `post_process_fn` to outputs of `fn`."""
@functools.wraps(fn)
def apply_fn(*args, **kwargs):
return post_process_fn(*fn(*args, **kwargs))
return apply_fn
def with_split_input(fn, split_index, time_axis=0):
"""Decorates `fn` to be evaluated on first `split_index` time slices.
The returned function is a generalization to pytrees of the function:
`fn(x[:split_index], *args, **kwargs)`
Args:
fn: function to be transformed.
split_index: number of input elements along the time axis to use.
    time_axis: axis in `x` corresponding to the time dimension.
Returns:
    The decorated `fn`, evaluated only on the first `split_index` time slices
    of the provided inputs.
"""
@functools.wraps(fn)
def apply_fn(x, *args, **kwargs):
init, _ = array_utils.split_along_axis(x, split_index, axis=time_axis)
return fn(init, *args, **kwargs)
return apply_fn
def with_input_included(trajectory_fn, time_axis=0):
"""Returns a `trajectory_fn` that concatenates inputs `x` to trajectory."""
@functools.wraps(trajectory_fn)
def _trajectory(x, *args, **kwargs):
final, unroll = trajectory_fn(x, *args, **kwargs)
return final, array_utils.concat_along_axis([x, unroll], time_axis)
return _trajectory
def decoded_trajectory_with_inputs(model, num_init_frames):
"""Returns trajectory_fn operating on decoded data.
  The returned function uses the first `num_init_frames` of the physics-space
  trajectory provided as input to initialize the model state, unrolls a
  trajectory of the specified length, and decodes it back to physics space
  using `model.decode`.
Args:
model: model of a dynamical system used to obtain the trajectory.
num_init_frames: number of time frames used from the physics trajectory to
initialize the model state.
Returns:
Trajectory function that operates on physics space trajectories and returns
unrolls in physics space.
"""
def _trajectory_fn(x, outer_steps, inner_steps=1):
trajectory_fn = functools.partial(
model.trajectory, post_process_fn=model.decode)
# add preprocessing to convert data to model state.
trajectory_fn = with_preprocessing(trajectory_fn, model.encode)
# concatenate input trajectory to output trajectory for easier comparison.
trajectory_fn = with_input_included(trajectory_fn)
# make trajectories operate on full examples by splitting the init.
trajectory_fn = with_split_input(trajectory_fn, num_init_frames)
return trajectory_fn(x, outer_steps, inner_steps)
return _trajectory_fn
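# Illustrative usage sketch (not part of the original module): `model` and
# `physics_trajectory` are hypothetical placeholders. The returned function
# encodes the first `num_init_frames` frames, unrolls and decodes the model
# trajectory, and prepends those initial frames along the time axis.
#
#   trajectory_fn = decoded_trajectory_with_inputs(model, num_init_frames=4)
#   final_state, decoded_unroll = trajectory_fn(
#       physics_trajectory, outer_steps=32)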