interpolation.py 13.5 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for interpolating `GridVariables`s."""

from typing import Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np

# Any array accepted as numeric input: a NumPy array or a JAX array.
Array = Union[np.ndarray, jax.Array]
# Short local names for the grid containers defined in `jax_cfd.base.grids`.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
# Call signature shared by the interpolation functions in this module:
# `(c, offset, v, dt) -> GridVariable`.
InterpolationFn = Callable[
    [GridVariable, Tuple[float, ...], GridVariableVector, float],
    GridVariable]
# A flux limiter maps an array to an array (see `van_leer_limiter`).
FluxLimiter = Callable[[grids.Array], grids.Array]


def _linear_along_axis(c: GridVariable,
                       offset: float,
                       axis: int) -> GridVariable:
  """Interpolates `c` linearly to `offset` along one `axis`.

  Args:
    c: quantity to be interpolated.
    offset: target offset along `axis`.
    axis: index of the axis along which to interpolate.

  Returns:
    `c` itself when it already sits at `offset` along `axis`. Otherwise a new
    `GridVariable` with the grid and boundary conditions of `c`, whose offset
    equals `c.offset` with the entry at `axis` replaced by `offset`.
  """
  delta = offset - c.offset[axis]
  if delta == 0:
    # Already located at the requested offset; nothing to interpolate.
    return c

  target_offset = tuple(
      offset if i == axis else current for i, current in enumerate(c.offset))

  if int(delta) == delta:
    # A whole number of cells away: a plain shift gives the exact values.
    values = c.shift(int(delta), axis).data
  else:
    # Blend the two neighbouring shifts. Each weight is the distance of the
    # target from the *other* neighbour, so the closer neighbour counts more.
    lower = int(np.floor(delta))
    upper = int(np.ceil(delta))
    lower_weight = upper - delta
    upper_weight = 1. - lower_weight
    values = (lower_weight * c.shift(lower, axis).data +
              upper_weight * c.shift(upper, axis).data)

  return grids.GridVariable(
      array=grids.GridArray(data=values, offset=target_offset, grid=c.grid),
      bc=c.bc)


def linear(
    c: GridVariable,
    offset: Tuple[float, ...],
    v: Optional[GridVariableVector] = None,
    dt: Optional[float] = None
) -> grids.GridVariable:
  """Multi-linear interpolation of `c` to `offset`.

  The interpolation is carried out one axis at a time, in axis order, by
  repeatedly applying `_linear_along_axis`.

  Args:
    c: quantity to be interpolated.
    offset: offset to which we will interpolate `c`. Must have the same length
      as `c.offset`.
    v: velocity field. Not used.
    dt: size of the time step. Not used.

  Returns:
    A `GridVariable` containing the values of `c` after linear interpolation
    to `offset`. The returned value will have offset equal to `offset`.

  Raises:
    ValueError: if `offset` and `c.offset` have different lengths.
  """
  del v, dt  # unused
  if len(offset) != len(c.offset):
    # Note the trailing space: the two literals are concatenated.
    raise ValueError('`c.offset` and `offset` must have the same length; '
                     f'got {c.offset} and {offset}.')
  interpolated = c
  for axis, axis_offset in enumerate(offset):
    interpolated = _linear_along_axis(
        interpolated, offset=axis_offset, axis=axis)
  return interpolated


def upwind(
    c: GridVariable,
    offset: Tuple[float, ...],
    v: GridVariableVector,
    dt: Optional[float] = None
) -> GridVariable:
  """Upwind interpolation of `c` to `offset` based on velocity field `v`.

  The interpolation proceeds in two steps:
  1) Find the single axis along which `c.offset` and `offset` differ.
  2) Where the velocity component along that axis is positive, take the value
     of `c` from the previous cell along the axis; elsewhere take the value
     from the next cell.

  Args:
    c: quantity to be interpolated.
    offset: offset to which `c` will be interpolated. Must have the same
      length as `c.offset` and differ in at most one entry.
    v: velocity field with offsets at faces of `c`. One of the components
      must have the same offset as `offset`.
    dt: size of the time step. Not used.

  Returns:
    `c` itself if it is already located at `offset`. Otherwise a
    `GridVariable` that contains the values of `c` after interpolation to
    `offset`. When the velocity component and `c` are a whole number of cells
    apart the result keeps `c.bc`; otherwise it carries periodic boundary
    conditions.

  Raises:
    InconsistentOffsetError: if `offset` and `c.offset` differ in more than one
    entry.
  """
  del dt  # unused
  if c.offset == offset:
    return c

  differing_axes = [
      i for i, (have, want) in enumerate(zip(c.offset, offset)) if have != want
  ]
  if len(differing_axes) != 1:
    raise grids.InconsistentOffsetError(
        f'for upwind interpolation `c.offset` and `offset` must differ at most '
        f'in one entry, but got: {c.offset} and {offset}.')
  axis = differing_axes[0]
  u = v[axis]
  # Distance, in cells, from `c` to the velocity component along `axis`.
  delta = u.offset[axis] - c.offset[axis]

  if int(delta) == delta:
    # A whole number of cells away: shifting `c` is exact, no upwinding needed.
    shifted = grids.GridArray(
        data=c.shift(int(delta), axis).data,
        offset=offset,
        grid=grids.consistent_grid(c, u))
    return grids.GridVariable(array=shifted, bc=c.bc)

  # Candidate values on either side of the target location.
  behind = c.shift(int(np.floor(delta)), axis).data
  ahead = c.shift(int(np.ceil(delta)), axis).data
  # Pick the candidate the flow is coming from.
  upwinded = grids.applied(jnp.where)(u.array > 0, behind, ahead)
  grid = grids.consistent_grid(c, u)
  return grids.GridVariable(
      array=grids.GridArray(upwinded.data, offset, grid),
      bc=boundaries.periodic_boundary_conditions(grid.ndim))


def lax_wendroff(
    c: GridVariable,
    offset: Tuple[float, ...],
    v: Optional[GridVariableVector] = None,
    dt: Optional[float] = None
) -> GridVariable:
  """Lax-Wendroff interpolation of `c` to `offset` based on velocity field `v`.

  Interpolates values of `c` to `offset` in two steps:
  1) Identifies the axis along which `c` is interpolated. (must be single axis)
  2) For positive (negative) velocity along interpolation axis uses value from
     the previous (next) cell along that axis plus a correction originating
     from expansion of the solution at the half step-size.

  This method is second order accurate with fixed coefficients and hence can't
  be monotonic due to Godunov's theorem.
  https://en.wikipedia.org/wiki/Godunov%27s_theorem

  Lax-Wendroff method can be used to form monotonic schemes when augmented with
  a flux limiter. See https://en.wikipedia.org/wiki/Flux_limiter

  Args:
    c: quantity to be interpolated.
    offset: offset to which we will interpolate `c`. Must have the same
      length as `c.offset` and differ in at most one entry.
    v: velocity field with offsets at faces of `c`. One of the components must
      have the same offset as `offset`. Required unless `offset` equals
      `c.offset`.
    dt: size of the time step; used to compute the Courant numbers
      `dt * u / grid.step[axis]`. Required unless `offset` equals `c.offset`.

  Returns:
    `c` itself if it is already located at `offset`. Otherwise a
    `GridVariable` with periodic boundary conditions that contains the values
    of `c` after interpolation to `offset`.

  Raises:
    InconsistentOffsetError: if `offset` and `c.offset` differ in more than one
    entry.
    ValueError: if interpolation is needed but `v` or `dt` is `None`.
  """
  # TODO(dkochkov) add a function to compute interpolation axis.
  if c.offset == offset: return c
  interpolation_axes = tuple(
      axis for axis, (current, target) in enumerate(zip(c.offset, offset))
      if current != target
  )
  if len(interpolation_axes) != 1:
    raise grids.InconsistentOffsetError(
        f'for Lax-Wendroff interpolation `c.offset` and `offset` must differ at'
        f' most in one entry, but got: {c.offset} and {offset}.')
  if v is None or dt is None:
    # Both are dereferenced below; fail with a clear message instead of an
    # opaque `TypeError` on `None`.
    raise ValueError(
        'lax_wendroff requires both `v` and `dt` when `offset` differs from '
        '`c.offset`.')
  axis, = interpolation_axes
  u = v[axis]
  offset_delta = u.offset[axis] - c.offset[axis]
  floor = int(np.floor(offset_delta))  # used for positive velocity
  ceil = int(np.ceil(offset_delta))  # used for negative velocity
  grid = grids.consistent_grid(c, u)
  courant_numbers = (dt / grid.step[axis]) * u.data
  # Values of `c` on either side of the target location and their difference.
  c_floor = c.shift(floor, axis).data
  c_ceil = c.shift(ceil, axis).data
  c_diff = c_ceil - c_floor
  # Upwind value plus the second-order correction, for each flow direction.
  positive_u_case = c_floor + 0.5 * (1 - courant_numbers) * c_diff
  negative_u_case = c_ceil - 0.5 * (1 + courant_numbers) * c_diff
  array = grids.where(u.array > 0, positive_u_case, negative_u_case)
  return grids.GridVariable(
      array=grids.GridArray(array.data, offset, grid),
      bc=boundaries.periodic_boundary_conditions(grid.ndim))


def safe_div(x, y, default_numerator=1):
  """Divides `x` by `y`, avoiding division by zero.

  Args:
    x: numerator.
    y: denominator.
    default_numerator: value substituted for entries of `y` that are equal to
      zero. Despite its name it replaces the *denominator*, so the result at
      those entries is `x / default_numerator`.

  Returns:
    `x / y`, with zero entries of `y` replaced by `default_numerator`.
  """
  denominator = jnp.where(y == 0, default_numerator, y)
  return x / denominator


def van_leer_limiter(r):
  """Van Leer flux limiter.

  Args:
    r: array of ratios to be limited.

  Returns:
    `2 * r / (1 + r)` where `r > 0` and `0.0` elsewhere.
  """
  # `safe_div` guards the `r == -1` entries; they are masked out below anyway.
  limited = safe_div(2 * r, 1 + r)
  return jnp.where(r > 0, limited, 0.0)


def apply_tvd_limiter(
    interpolation_fn: InterpolationFn,
    limiter: FluxLimiter = van_leer_limiter
) -> InterpolationFn:
  """Combines low and high accuracy interpolators to get TVD method.

  Builds an interpolator that blends the low accuracy `upwind` interpolation
  with the high accuracy (but not guaranteed to be stable) `interpolation_fn`,
  using `limiter` to decide how much of the high accuracy correction to keep.
  This implementation follows the procedure outlined in:
  http://www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf

  Args:
    interpolation_fn: higher order interpolation method. Must follow the same
      interface as other interpolation methods (take `c`, `offset`, `v` and
      `dt` arguments and return value of `c` at offset `offset`).
    limiter: flux limiter function that evaluates the portion of the correction
      (high_accuracy - low_accuracy) to add to low_accuracy solution based on
      the ratio of the consecutive gradients. Takes array as input and returns
      array of weights. For more details see:
      https://en.wikipedia.org/wiki/Flux_limiter

  Returns:
    Interpolation method that uses a combination of high and low order methods
    to produce monotonic interpolation method.
  """

  def tvd_interpolation(
      c: GridVariable,
      offset: Tuple[float, ...],
      v: GridVariableVector,
      dt: float,
  ) -> GridVariable:
    """Interpolates `c` to `offset`, limiting the high order correction.

    Raises:
      NotImplementedError: if, along some axis, `offset` is not exactly half a
        cell ahead of `c.offset`.
    """
    for axis, axis_offset in enumerate(offset):
      # Offset of `c` with only the entry for the current axis replaced.
      target_offset = tuple(
          axis_offset if i == axis else current
          for i, current in enumerate(c.offset))
      if target_offset == c.offset:
        continue  # nothing to interpolate along this axis.
      if target_offset[axis] - c.offset[axis] != 0.5:
        raise NotImplementedError('tvd_interpolation only supports forward '
                                  'interpolation to control volume faces.')

      # NOTE(review): the full `offset`, not `target_offset`, is handed to both
      # interpolators. `upwind` rejects offsets that differ from `c.offset` in
      # more than one entry, so this appears to restrict usage to a single
      # interpolation axis -- confirm this is intended.
      low_order = upwind(c, offset, v, dt)
      high_order = interpolation_fn(c, offset, v, dt)

      # The face lies between `c` and its right neighbour; the limiter also
      # needs one more point on each side.
      left = c.shift(-1, axis)
      right = c.shift(1, axis)
      next_right = c.shift(2, axis)
      # Gradient across the face, shared by both ratios below.
      face_gradient = right.data - c.data
      # Velocities of different sign are evaluated with limiters at different
      # points. See equations (4.34) -- (4.39) from the reference above.
      r_positive_u = safe_div(c.data - left.data, face_gradient)
      r_negative_u = safe_div(next_right.data - right.data, face_gradient)
      phi_positive_u = grids.GridArray(
          limiter(r_positive_u), low_order.offset, c.grid)
      phi_negative_u = grids.GridArray(
          limiter(r_negative_u), low_order.offset, c.grid)
      u = v[axis]
      phi = grids.applied(jnp.where)(
          u.array > 0, phi_positive_u, phi_negative_u)
      # phi == 0 keeps the low order value, phi == 1 the high order one.
      blended = low_order.array - (low_order.array - high_order.array) * phi
      c = grids.GridVariable(
          grids.GridArray(blended.data, target_offset, c.grid),
          c.bc)
    return c

  return tvd_interpolation


# TODO(pnorgaard) Consider changing c to GridVariable
# Not required since no .shift() method is used
def point_interpolation(
    point: Array,
    c: GridArray,
    order: int = 1,
    mode: str = 'nearest',
    cval: float = 0.0,
) -> jax.Array:
  """Interpolate `c` at `point`.

  The physical coordinates in `point` are converted to (fractional) array
  indices of `c.data`, which are then passed to
  `jax.scipy.ndimage.map_coordinates`.

  Args:
    point: length N 1-D Array. The point to interpolate to.
    c: N-dimensional GridArray. The values that will be interpolated.
    order: Integer in the range 0-1. The order of the spline interpolation.
    mode: one of {'reflect', 'constant', 'nearest', 'mirror', 'wrap'}.
      The `mode` parameter determines how the input array is extended
      beyond its boundaries. Default is 'nearest'. Behavior for each valid
      value is as follows:
      'reflect' (`d c b a | a b c d | d c b a`)
          The input is extended by reflecting about the edge of the last
          pixel.
      'constant' (`k k k k | a b c d | k k k k`)
          The input is extended by filling all values beyond the edge with
          the same constant value, defined by the `cval` parameter.
      'nearest' (`a a a a | a b c d | d d d d`)
          The input is extended by replicating the last pixel.
      'mirror' (`d c b | a b c d | c b a`)
          The input is extended by reflecting about the center of the last
          pixel.
      'wrap' (`a b c d | a b c d | a b c d`)
          The input is extended by wrapping around to the opposite edge.
    cval: Value to fill past edges of input if `mode` is 'constant'. Default 0.0

  Returns:
    the interpolated value at `point`.
  """
  point = jnp.asarray(point)

  # Per-dimension lower and upper bounds of the domain.
  lower, upper = (jnp.array(bounds) for bounds in zip(*c.grid.domain))
  cell_counts = jnp.array(c.grid.shape)
  cell_offset = jnp.array(c.offset)

  # Along each dimension `i` the map from `point[i]` to an index is linear:
  #   index(lower[i]) = -cell_offset[i]
  #   index(upper[i]) = cell_counts[i] - cell_offset[i]
  # which vectorizes over all dimensions as follows.
  scaled_position = (point - lower) * cell_counts / (upper - lower)
  index = -cell_offset + scaled_position

  return jax.scipy.ndimage.map_coordinates(
      c.data, coordinates=index, order=order, mode=mode, cval=cval)