interpolations.py 8.4 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""Interpolation modules."""

import collections
import functools
from typing import Any, Callable, Tuple, Union

import gin
import jax.numpy as jnp
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import layers
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers
import numpy as np


# Short aliases for the jax_cfd grid containers used in signatures below.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
# Interpolation function type defined in `jax_cfd.base.interpolation`.
InterpolationFn = interpolation.InterpolationFn
# Factory building an InterpolationFn; `transformed` below calls it as
# `module(grid, dt, physics_specs, v=v)`.
InterpolationModule = Callable[..., InterpolationFn]
# Callable that takes an InterpolationFn and returns a modified one.
InterpolationTransform = Callable[..., InterpolationFn]
FluxLimiter = interpolation.FluxLimiter


# Maps (source_offset, target_offset, tag) to a per-axis stencil size.
# Grid offsets in jax_cfd are float tuples, hence Tuple[float, ...] for the
# first two arguments; the returned stencil sizes are integers.
StencilSizeFn = Callable[
    [Tuple[float, ...], Tuple[float, ...], Any], Tuple[int, ...]]


@gin.register
class FusedLearnedInterpolation:
  """Learned interpolator that computes interpolation coefficients in 1 pass.

  Interpolation function that has pre-computed interpolation
  coefficients for a given velocity field `v`. It uses a collection of
  `SpatialDerivativeFromLogits` modules and a single neural network that
  produces logits for all expected interpolations. Interpolations are keyed by
  `input_offset`, `target_offset` and an optional `tag`. The `tag` allows us to
  perform multiple interpolations between the same `offset` and `target_offset`
  with different weights.
  """

  def __init__(
      self,
      grid: grids.Grid,
      dt: float,
      physics_specs: physics_specifications.BasePhysicsSpecs,
      v,
      tags=(None,),
      stencil_size: Union[int, StencilSizeFn] = 4,
      tower_factory=towers.forward_tower_factory,
      name='fused_learned_interpolation',
      extract_patch_method='roll',
      fuse_constraints=False,
      fuse_patches=False,
      constrain_with_conv=False,
      tile_layout=None,
  ):
    """Constructs the object and performs the necessary pre-computation.

    One `layers.SpatialDerivativeFromLogits` is built for every
    (component offset, control-volume target offset, tag) triple. A single
    network from `tower_factory` is evaluated on the stacked components of `v`,
    and each interpolator is bound to its own slice of the resulting logits.

    Args:
      grid: grid on which the simulation is discretized.
      dt: time step; unused.
      physics_specs: physical parameters of the simulation; unused.
      v: velocity components, each exposing `.data` and `.offset`. Their data
        is stacked along a new trailing axis as the network input, and their
        offsets determine which interpolations are precomputed.
      tags: tags for which a separate interpolator is built for every
        (source offset, target offset) pair.
      stencil_size: either an int used along every axis, or a function mapping
        `(source_offset, target_offset, tag)` to per-axis stencil sizes.
      tower_factory: factory for the network that produces all logits.
      name: name passed to `tower_factory`.
      extract_patch_method: forwarded to `SpatialDerivativeFromLogits`.
      fuse_constraints: if True, interpolators are built with
        `layers.fuse_spatial_derivative_layers` from the full logits tensor.
      fuse_patches: forwarded to `fuse_spatial_derivative_layers`.
      constrain_with_conv: forwarded to `fuse_spatial_derivative_layers`.
      tile_layout: forwarded to `SpatialDerivativeFromLogits`.
    """
    del dt, physics_specs  # unused.

    derivative_orders = (0,) * grid.ndim
    derivatives = collections.OrderedDict()

    if isinstance(stencil_size, int):
      stencil_size_fn = lambda *_: (stencil_size,) * grid.ndim
    else:
      stencil_size_fn = stencil_size

    for u in v:
      for target_offset in grids.control_volume_offsets(u):
        for tag in tags:
          key = (u.offset, target_offset, tag)
          derivatives[key] = layers.SpatialDerivativeFromLogits(
              stencil_size_fn(*key),
              u.offset,
              target_offset,
              derivative_orders=derivative_orders,
              steps=grid.step,
              extract_patch_method=extract_patch_method,
              tile_layout=tile_layout)

    output_sizes = [deriv.subspace_size for deriv in derivatives.values()]
    cnn_network = tower_factory(sum(output_sizes), grid.ndim, name=name)
    inputs = jnp.stack([u.data for u in v], axis=-1)
    all_logits = cnn_network(inputs)

    if fuse_constraints:
      self._interpolators = layers.fuse_spatial_derivative_layers(
          derivatives, all_logits, fuse_patches=fuse_patches,
          constrain_with_conv=constrain_with_conv)
    else:
      # Split points between consecutive interpolators. The final cumulative
      # sum (the total size) is dropped so that exactly one chunk is produced
      # per interpolator, with no trailing empty chunk.
      split_points = np.cumsum(output_sizes)[:-1]
      split_logits = jnp.split(all_logits, split_points, axis=-1)
      self._interpolators = {
          k: functools.partial(derivative, logits=logits)
          for (k, derivative), logits in zip(derivatives.items(), split_logits)
      }

  def __call__(self,
               c: GridVariable,
               offset: Tuple[float, ...],
               v: GridVariableVector,
               dt: float,
               tag=None) -> GridVariable:
    """Interpolates `c` to `offset` using the precomputed interpolators.

    Args:
      c: variable to interpolate.
      offset: target offset.
      v: velocity field; unused here because the coefficients were already
        computed in `__init__`.
      dt: time step; unused.
      tag: optional tag selecting among interpolators built for the same pair
        of offsets.

    Returns:
      `c` itself if it already lives at `offset`; otherwise a new GridVariable
      at `offset` with the same grid and boundary conditions as `c`.

    Raises:
      KeyError: if no interpolator exists for `(c.offset, offset, tag)`.
    """
    del dt  # not used.
    # TODO(jamieas): Try removing the following line.
    # This check runs before the channel dimension is added. Returning after
    # expand_dims would hand back data with a spurious trailing axis.
    if c.offset == offset:
      return c
    key = (c.offset, offset, tag)
    interpolator = self._interpolators.get(key)
    if interpolator is None:
      raise KeyError(f'No interpolator for key {key}. '
                     f'Available keys: {list(self._interpolators.keys())}')
    # TODO(dkochkov) Add decorator to expand/squeeze channel dim.
    # Interpolators operate on data with a trailing channel axis.
    result = jnp.squeeze(interpolator(jnp.expand_dims(c.data, -1)), axis=-1)
    return grids.GridVariable(
        grids.GridArray(result, offset, c.grid), c.bc)


def _nearest_neighhbor_stencil_size_fn(
    source_offset, target_offset, tag, stencil_size,
):
  del tag  # unused
  return tuple(
      1 if s == t else stencil_size
      for s, t in zip(source_offset, target_offset)
  )


@gin.register
def anisotropic_learned_interpolation(*args, stencil_size=2, **kwargs):
  """Like FusedLearnedInterpolation, but with an anisotropic stencil.

  Along axes where the source and target offsets agree, the stencil has size 1.
  Every other axis uses `stencil_size`.

  Args:
    *args: positional arguments forwarded to `FusedLearnedInterpolation`.
    stencil_size: stencil size along axes where the offsets differ.
    **kwargs: keyword arguments forwarded to `FusedLearnedInterpolation`.

  Returns:
    A `FusedLearnedInterpolation` configured with the anisotropic stencil.
  """
  # `stencil_size` is a named parameter, so it can never already be a key in
  # `kwargs`. Injecting it here is therefore safe.
  kwargs['stencil_size'] = functools.partial(
      _nearest_neighhbor_stencil_size_fn, stencil_size=stencil_size)
  return FusedLearnedInterpolation(*args, **kwargs)


@gin.register
class IndividualLearnedInterpolation:
  """Trainable interpolation module.

  This module uses a collection of SpatialDerivative modules that are applied
  to inputs based on the combination of initial and target offsets. Currently
  no symmetries are implemented and every new pair of offsets (and every
  distinct tag) gets a separate network.
  """

  def __init__(
      self,
      grid: grids.Grid,
      dt: float,
      physics_specs: physics_specifications.BasePhysicsSpecs,
      v: GridArrayVector,
      stencil_size=4,
      tower_factory=towers.forward_tower_factory,
  ):
    """Stores configuration; learned modules are created lazily on first use.

    Args:
      grid: grid on which the simulation is discretized.
      dt: time step; unused.
      physics_specs: physical parameters of the simulation; unused.
      v: velocity field; unused at construction time.
      stencil_size: stencil size used along every axis.
      tower_factory: network factory; `ndim` is bound to `grid.ndim`.
    """
    del v, dt, physics_specs  # unused.
    self._ndim = grid.ndim
    self._tower_factory = functools.partial(tower_factory, ndim=grid.ndim)
    self._stencil_sizes = (stencil_size,) * self._ndim
    self._steps = grid.step
    self._modules = {}

  def _get_interpolation_module(self, offsets, tag=None):
    """Constructs or retrieves a learned interpolation module.

    Args:
      offsets: pair `(inputs_offset, target_offset)`.
      tag: optional tag. Distinct tags get distinct modules, so the same pair
        of offsets can be interpolated with different weights.

    Returns:
      The cached `layers.SpatialDerivative` for `(offsets, tag)`.
    """
    key = (offsets, tag)
    if key not in self._modules:
      inputs_offset, target_offset = offsets
      self._modules[key] = layers.SpatialDerivative(
          self._stencil_sizes, inputs_offset, target_offset,
          (0,) * self._ndim, self._tower_factory, self._steps)
    return self._modules[key]

  def __call__(
      self,
      c: GridVariable,
      offset: Tuple[float, ...],
      v: GridVariableVector,
      dt: float,
      tag=None,
  ) -> GridVariable:
    """Interpolates `c` to `offset`.

    Args:
      c: variable to interpolate.
      offset: target offset.
      v: velocity field; each component's data is passed to the learned module
        as an auxiliary input with a trailing channel axis.
      dt: time step; unused.
      tag: optional tag, matching the `FusedLearnedInterpolation.__call__`
        signature. Each distinct tag uses its own learned module.

    Returns:
      `c` itself if it already lives at `offset`; otherwise a new GridVariable
      at `offset` with the same grid and boundary conditions as `c`.
    """
    del dt  # not used.
    if c.offset == offset:
      return c
    offsets = (c.offset, offset)
    c_input = jnp.expand_dims(c.data, axis=-1)
    aux_inputs = [jnp.expand_dims(u.data, axis=-1) for u in v]
    module = self._get_interpolation_module(offsets, tag)
    res = module(c_input, *aux_inputs)
    return grids.GridVariable(
        grids.GridArray(jnp.squeeze(res, axis=-1), offset, c.grid), c.bc)


@gin.register
def linear(*unused_args, **unused_kwargs):
  """Interpolation module that always returns `interpolation.linear`.

  Any arguments (for example the `grid, dt, physics_specs, v=v` passed by
  `transformed`) are accepted and ignored.
  """
  return interpolation.linear


@gin.register
def upwind(*unused_args, **unused_kwargs):
  """Interpolation module that always returns `interpolation.upwind`.

  Any arguments (for example the `grid, dt, physics_specs, v=v` passed by
  `transformed`) are accepted and ignored.
  """
  return interpolation.upwind


@gin.register
def lax_wendroff(*unused_args, **unused_kwargs):
  """Interpolation module that always returns `interpolation.lax_wendroff`.

  Any arguments (for example the `grid, dt, physics_specs, v=v` passed by
  `transformed`) are accepted and ignored.
  """
  return interpolation.lax_wendroff


# TODO(dkochkov) make flux limiters configurable.
@gin.register
def tvd_limiter_transformation(
    interpolation_fn: InterpolationFn,
    limiter_fn: FluxLimiter = interpolation.van_leer_limiter,
) -> InterpolationFn:
  """Wraps `interpolation_fn` with a TVD flux limiter.

  Args:
    interpolation_fn: interpolation function to be limited.
    limiter_fn: flux limiter handed to `interpolation.apply_tvd_limiter`;
      defaults to `interpolation.van_leer_limiter`.

  Returns:
    The limited interpolation function produced by
    `interpolation.apply_tvd_limiter`.
  """
  limited_interpolation_fn = interpolation.apply_tvd_limiter(
      interpolation_fn, limiter_fn)
  return limited_interpolation_fn


@gin.register
def transformed(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    v: GridArrayVector,
    base_interpolation_module: InterpolationModule = lax_wendroff,
    transformation: InterpolationTransform = tvd_limiter_transformation,
) -> InterpolationFn:
  """Builds a base interpolation function and applies a transformation to it.

  The function produced by `base_interpolation_module` is passed through
  `transformation`. This lets extra constraints be layered on top of a base
  scheme. For example, a TVD constraint is imposed by using a
  `transformation` that applies a TVD flux limiter.

  Args:
    grid: grid on which the Navier-Stokes equation is discretized.
    dt: time step to use for time evolution.
    physics_specs: physical parameters of the simulation module.
    v: input velocity field potentially used to pre-compute interpolations.
    base_interpolation_module: base interpolation module to use.
    transformation: transformation to apply to base interpolation function.

  Returns:
    The interpolation function returned by `transformation`.
  """
  base_fn = base_interpolation_module(grid, dt, physics_specs, v=v)
  return transformation(base_fn)