equations.py 9.16 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""Implementations of equation modules."""

from typing import Any, Callable, Tuple

import gin
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd import spectral
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import equations
from jax_cfd.base import grids
from jax_cfd.ml import advections
from jax_cfd.ml import diffusions
from jax_cfd.ml import forcings
from jax_cfd.ml import networks  # pylint: disable=unused-import
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import pressures
from jax_cfd.ml import time_integrators
from jax_cfd.spectral import utils as spectral_utils

# Short local aliases for module types declared in the `jax_cfd.ml`
# submodules; several are used in the signature annotations of the model
# factories defined in this file.
ConvectionModule = advections.ConvectionModule
DiffuseModule = diffusions.DiffuseModule
DiffusionSolveModule = diffusions.DiffusionSolveModule
ForcingModule = forcings.ForcingModule
PressureModule = pressures.PressureModule
# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic


# TODO(dkochkov) move diffusion to modular_navier_stokes after b/160947162.
@gin.register(denylist=("grid", "dt", "physics_specs"))
def semi_implicit_navier_stokes(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    diffusion_module: DiffuseModule = diffusions.diffuse,
    **kwargs,
):
  """Builds a semi-implicit Navier-Stokes step module with explicit diffusion.

  Args:
    grid: grid forwarded to both the diffusion module and the base solver.
    dt: time step forwarded to both the diffusion module and the base solver.
    physics_specs: physical parameters; here they are only handed to
      `diffusion_module`.
    diffusion_module: factory called as
      `diffusion_module(grid, dt, physics_specs)`; its result is passed to the
      base solver as `diffuse`.
    **kwargs: forwarded unchanged to `equations.semi_implicit_navier_stokes`.

  Returns:
    An instance of the Haiku module that wraps the step function returned by
    `equations.semi_implicit_navier_stokes`.
  """
  step_module_cls = hk.to_module(
      equations.semi_implicit_navier_stokes(
          grid=grid,
          dt=dt,
          diffuse=diffusion_module(grid, dt, physics_specs),
          **kwargs))
  return step_module_cls()


@gin.register(denylist=("grid", "dt", "physics_specs"))
def implicit_diffusion_navier_stokes(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    diffusion_module: DiffusionSolveModule = diffusions.solve_fast_diag,
    **kwargs
):
  """Builds a Navier-Stokes step module that uses an implicit diffusion solve.

  Args:
    grid: grid forwarded to both the diffusion module and the base solver.
    dt: time step forwarded to both the diffusion module and the base solver.
    physics_specs: physical parameters; here they are only handed to
      `diffusion_module`.
    diffusion_module: factory called as
      `diffusion_module(grid, dt, physics_specs)`; its result is passed to the
      base solver as `diffusion_solve`.
    **kwargs: forwarded unchanged to
      `equations.implicit_diffusion_navier_stokes`.

  Returns:
    An instance of the Haiku module that wraps the step function returned by
    `equations.implicit_diffusion_navier_stokes`.
  """
  step_module_cls = hk.to_module(
      equations.implicit_diffusion_navier_stokes(
          grid=grid,
          dt=dt,
          diffusion_solve=diffusion_module(grid, dt, physics_specs),
          **kwargs))
  return step_module_cls()


@gin.register(denylist=("grid", "dt", "physics_specs"))
def modular_spectral_step_fn(
    grid,
    dt,
    physics_specs,
    do_filter_step=False,
    time_stepper=spectral.time_stepping.crank_nicolson_rk4,
    ):
  """Returns a spectral solver module for forced Navier-Stokes flows.

  Args:
    grid: grid passed to `spectral.equations.NavierStokes2D`.
    dt: time step handed to `time_stepper`.
    physics_specs: object whose `viscosity`, `drag`, `forcing_module` and
      `smooth` attributes configure the equation.
    do_filter_step: if True, the output of every step is transformed with
      `irfft2`, passed through `spectral_utils.exponential_filter` and
      transformed back with `rfft2`.
    time_stepper: called as `time_stepper(equation, dt)` to build the step
      function.

  Returns:
    An instance of the Haiku module wrapping the (optionally filtered) step
    function.
  """
  equation = spectral.equations.NavierStokes2D(
      physics_specs.viscosity,
      grid,
      drag=physics_specs.drag,
      forcing_fn=physics_specs.forcing_module,
      smooth=physics_specs.smooth)
  advance = time_stepper(equation, dt)

  if not do_filter_step:
    return hk.to_module(advance)()

  # A named function is used because lambdas don't play nice with gin config.
  def ret(vhat):
    # TODO(dresdner) unnecessary fft's
    real_space = jnp.fft.irfft2(advance(vhat))
    filtered = spectral_utils.exponential_filter(real_space)
    return jnp.fft.rfft2(filtered)

  return hk.to_module(ret)()


@gin.configurable(denylist=("grid", "dt", "physics_specs"))
def modular_navier_stokes_model(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    equation_solver=implicit_diffusion_navier_stokes,
    convection_module: ConvectionModule = advections.self_advection,
    pressure_module: PressureModule = pressures.fast_diagonalization,
    acceleration_modules=(),
):
  """Returns an incompressible Navier-Stokes time step model.

  The model is assembled from standard components of numerical solvers, each
  of which could be replaced with a learned component. The diffusion module is
  not an argument here: it is configured on the `equation_solver`, because
  implicit and explicit schemes need different diffusion modules.

  Args:
    grid: grid on which the Navier-Stokes equation is discretized.
    dt: time step to use for time evolution.
    physics_specs: physical parameters of the simulation module.
    equation_solver: solver to call to create a time-stepping function.
    convection_module: module to use to simulate convection.
    pressure_module: module to use to perform pressure projection.
    acceleration_modules: additional explicit terms to be added to the equation
      before the pressure projection step.

  Returns:
    A Haiku module that maps a velocity state (an iterable of `GridVariable`s)
    to the result of applying the step function built by `equation_solver`.
  """
  active_forcing_fn = physics_specs.forcing_module(grid)

  def navier_stokes_step_fn(state):
    """Advances Navier-Stokes state forward in time.

    Raises:
      ValueError: if any component of `state` is not a `GridVariable`.
    """
    for component in state:
      if isinstance(component, grids.GridVariable):
        continue
      raise ValueError(f"Expected GridVariable type, got {type(component)}")

    # Sub-modules are built in a fixed order: convection, accelerations,
    # pressure, then the solver itself.
    convect = convection_module(grid, dt, physics_specs, v=state)
    extra_terms = []
    for make_acceleration in acceleration_modules:
      extra_terms.append(make_acceleration(grid, dt, physics_specs, v=state))
    total_forcing = forcings.sum_forcings(active_forcing_fn, *extra_terms)
    pressure_solve = pressure_module(grid, dt, physics_specs)

    advance = equation_solver(
        grid=grid,
        dt=dt,
        physics_specs=physics_specs,
        density=physics_specs.density,
        viscosity=physics_specs.viscosity,
        pressure_solve=pressure_solve,
        convect=convect,
        forcing=total_forcing)
    return advance(state)

  return hk.to_module(navier_stokes_step_fn)()


@gin.register
def time_derivative_network_model(
    grid: grids.Grid,
    dt: float,
    physics_specs: Any,
    derivative_modules: Tuple[Callable, ...],
    time_integrator=time_integrators.euler_integrator,
):
  """Returns a ML model that performs time stepping by time integration.

  Note: the model state is assumed to be a stack of observable values
  along the last axis.

  Args:
    grid: grid specifying spatial discretization of the physical system.
    dt: time step to use for time evolution.
    physics_specs: physical parameters of the simulation module.
    derivative_modules: tuple of modules that are used sequentially to compute
      unforced time derivative of the input state, which is then integrated.
    time_integrator: time integration scheme to use.

  Returns:
    `step_fn` that advances the input state forward in time by `dt`.
  """
  active_forcing_fn = physics_specs.forcing_module(grid)

  def step_fn(state):
    """Advances `state` forward in time by `dt`."""
    stages = [
        make_stage(grid, dt, physics_specs) for make_stage in derivative_modules
    ]

    def time_derivative_fn(x):
      """Applies the stages to `x` in order and adds the forcing terms."""
      # Rebuild per-component grid variables from the stacked channels so that
      # the forcing function can be evaluated on them.
      channels = array_utils.split_axis(x, axis=-1)
      arrays = tuple(
          grids.GridArray(data, offset, grid)
          for data, offset in zip(channels, grid.cell_faces))
      # TODO(pnorgaard) Explicitly specify boundary conditions for ML model
      periodic = boundaries.periodic_boundary_conditions(grid.ndim)
      velocity = tuple(grids.GridVariable(array, periodic) for array in arrays)
      forcing_terms = active_forcing_fn(velocity)
      forcing_scalars = jnp.stack(
          [term.data for term in forcing_terms], axis=-1)
      # TODO(dkochkov) consider conditioning on the forcing terms.
      derivative = x
      for stage in stages:
        derivative = stage(derivative)
      return derivative + forcing_scalars

    time_derivative_module = hk.to_module(time_derivative_fn)()
    # Only the first element of the integrator's two-element result is used.
    advanced, _ = time_integrator(time_derivative_module, state, dt, 1)
    return advanced

  return hk.to_module(step_fn)()


@gin.register
def identity_model(grid, dt, physics_specs):
  """A model that just returns the original state.

  Args:
    grid: unused.
    dt: unused.
    physics_specs: unused.

  Returns:
    A plain Python function `step_fn(state)` that returns `state` unchanged.
    Note that the result is not wrapped with `hk.to_module`.
  """
  # The arguments are accepted but deliberately discarded.
  del grid, dt, physics_specs
  def step_fn(state):
    """Returns `state` unchanged."""
    return state
  return step_fn


@gin.register
def learned_corrector(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    base_solver_module: Callable,
    corrector_module: Callable,
):
  """Returns a model that uses base solver with ML correction step.

  The idea is similar to solver in the loop in
  https://arxiv.org/abs/2007.00016 and to the learned corrector in
  https://arxiv.org/pdf/2102.01010.pdf.

  Args:
    grid: grid passed to both module factories.
    dt: time step passed to both module factories.
    physics_specs: physical parameters passed to both module factories.
    base_solver_module: factory called as `(grid, dt, physics_specs)` that
      returns the base solver.
    corrector_module: factory called as `(grid, dt, physics_specs)` that
      returns the corrector, which is applied to the base solver's output.

  Returns:
    A Haiku module computing `base_solver(state) + corrector(base_solver(state))`
    leaf by leaf.
  """
  # The base solver is constructed before the corrector; keep this order.
  solve = base_solver_module(grid, dt, physics_specs)
  correct = corrector_module(grid, dt, physics_specs)

  def _add(value, delta):
    return value + delta

  def step_fn(state):
    """Runs the base solver and adds the correction to its output."""
    predicted = solve(state)
    deltas = correct(predicted)
    return jax.tree_util.tree_map(_add, predicted, deltas)

  return hk.to_module(step_fn)()


@gin.register
def learned_corrector_v2(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    base_solver_module: Callable,
    corrector_module: Callable,
):
  """Like learned_corrector, but based on the input rather than output state.

  Args:
    grid: grid passed to both module factories.
    dt: time step passed to both module factories; it also scales the
      correction.
    physics_specs: physical parameters passed to both module factories.
    base_solver_module: factory called as `(grid, dt, physics_specs)` that
      returns the base solver.
    corrector_module: factory called as `(grid, dt, physics_specs)` that
      returns the corrector, which is applied to the input state.

  Returns:
    A Haiku module computing `base_solver(state) + dt * corrector(state)` leaf
    by leaf.
  """
  # The base solver is constructed before the corrector; keep this order.
  solve = base_solver_module(grid, dt, physics_specs)
  correct = corrector_module(grid, dt, physics_specs)

  def _add_scaled(value, delta):
    return value + dt * delta

  def step_fn(state):
    """Runs the base solver and adds the `dt`-scaled correction."""
    predicted = solve(state)
    deltas = correct(state)
    return jax.tree_util.tree_map(_add_scaled, predicted, deltas)

  return hk.to_module(step_fn)()


@gin.register
def learned_corrector_v3(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    base_solver_module: Callable,
    corrector_module: Callable,
):
  """Like learned_corrector, but based on input & output states.

  Args:
    grid: grid passed to both module factories.
    dt: time step passed to both module factories; it also scales the
      correction.
    physics_specs: physical parameters passed to both module factories.
    base_solver_module: factory called as `(grid, dt, physics_specs)` that
      returns the base solver.
    corrector_module: factory called as `(grid, dt, physics_specs)` that
      returns the corrector, which receives the input state components
      followed by the base solver's output components as a single tuple.

  Returns:
    A Haiku module computing
    `base_solver(state) + dt * corrector(state + next_state)` leaf by leaf.
  """
  # The base solver is constructed before the corrector; keep this order.
  solve = base_solver_module(grid, dt, physics_specs)
  correct = corrector_module(grid, dt, physics_specs)

  def _add_scaled(value, delta):
    return value + dt * delta

  def step_fn(state):
    """Runs the base solver and adds the `dt`-scaled correction."""
    predicted = solve(state)
    corrector_inputs = tuple(state) + tuple(predicted)
    deltas = correct(corrector_inputs)
    return jax.tree_util.tree_map(_add_scaled, predicted, deltas)

  return hk.to_module(step_fn)()