"example/2_gemm_xdl_bias_add/README.md" did not exist on "457c024d608c9b855775cc014a630c7e0d30710c"
advection.py 16.1 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for functionality related to advection."""

from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids
from jax_cfd.base import interpolation

GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
InterpolationFn = interpolation.InterpolationFn
# TODO(dkochkov) Consider testing if we need operator splitting methods.


def _advect_aligned(cs: GridVariableVector, v: GridVariableVector) -> GridArray:
  """Computes fluxes and the associated advection for aligned `cs` and `v`.

  The values `cs` should consist of a single quantity `c` that has been
  interpolated to the offset of the components of `v`. The components of `v` and
  `cs` should be located at the faces of a single (possibly offset) grid cell.
  We compute the advection as the divergence of the flux on this control volume.

  The boundary condition on the flux is inherited from the scalar quantity `c`.

  A typical example in three dimensions would have

  ```
  cs[0].offset == v[0].offset == (1., .5, .5)
  cs[1].offset == v[1].offset == (.5, 1., .5)
  cs[2].offset == v[2].offset == (.5, .5, 1.)
  ```

  In this case, the returned advection term would have offset `(.5, .5, .5)`.

  Args:
    cs: a sequence of `GridArray`s; a single value `c` that has been
      interpolated so that it is aligned with each component of `v`.
    v: a sequence of `GridArrays` describing a velocity field. Should be defined
      on the same Grid as cs.

  Returns:
    An `GridArray` containing the time derivative of `c` due to advection by
    `v`.

  Raises:
    ValueError: `cs` and `v` have different numbers of components.
    AlignmentError: if the components of `cs` are not aligned with those of `v`.
  """
  # TODO(jamieas): add more sophisticated alignment checks, ensuring that the
  # values are located on the faces of a control volume.
  if len(cs) != len(v):
    raise ValueError('`cs` and `v` must have the same length;'
                     f'got {len(cs)} vs. {len(v)}.')
  flux = tuple(c.array * u.array for c, u in zip(cs, v))
  bcs = tuple(
      boundaries.get_advection_flux_bc_from_velocity_and_scalar(v[i], cs[i], i)
      for i in range(len(v)))
  flux = tuple(bc.impose_bc(f) for f, bc in zip(flux, bcs))
  return -fd.divergence(flux)


def advect_general(
    c: GridVariable,
    v: GridVariableVector,
    u_interpolation_fn: InterpolationFn,
    c_interpolation_fn: InterpolationFn,
    dt: Optional[float] = None) -> GridArray:
  """Computes advection of a scalar quantity `c` by the velocity field `v`.

  This function follows the following procedure:

    1. Interpolate each component of `v` to the corresponding face of the
       control volume centered on `c`.
    2. Interpolate `c` to the same control volume faces.
    3. Compute the flux `cu` using the aligned values.
    4. Set the boundary condition on flux, which is inhereited from `c`.
    5. Return the negative divergence of the flux.

  Args:
    c: the quantity to be transported.
    v: a velocity field. Should be defined on the same Grid as c.
    u_interpolation_fn: method for interpolating velocity field `v`.
    c_interpolation_fn: method for interpolating scalar field `c`.
    dt: unused time-step.

  Returns:
    The time derivative of `c` due to advection by `v`.
  """
  if not boundaries.has_all_periodic_boundary_conditions(c):
    raise NotImplementedError(
        'Non-periodic boundary conditions are not implemented.')
  target_offsets = grids.control_volume_offsets(c)
  aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt)
                    for u, target_offset in zip(v, target_offsets))
  aligned_c = tuple(c_interpolation_fn(c, target_offset, aligned_v, dt)
                    for target_offset in target_offsets)
  return _advect_aligned(aligned_c, aligned_v)


def advect_linear(c: GridVariable,
                  v: GridVariableVector,
                  dt: Optional[float] = None) -> GridArray:
  """Computes advection using linear interpolations."""
  return advect_general(c, v, interpolation.linear, interpolation.linear, dt)


def advect_upwind(c: GridVariable,
                  v: GridVariableVector,
                  dt: Optional[float] = None) -> GridArray:
  """Computes advection using first-order upwind interpolation on `c`."""
  return advect_general(c, v, interpolation.linear, interpolation.upwind, dt)


def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]:
  """Returns interpolated components of `v` needed for convection.

  Args:
    v: a velocity field.

  Returns:
    A d-tuple of d-tuples of `GridVariable`s `aligned_v`, where `d = len(v)`.
    The entry `aligned_v[i][j]` is the component `v[i]` after interpolation to
    the appropriate face of the control volume centered around `v[j]`.
  """
  grid = grids.consistent_grid(*v)
  offsets = tuple(grids.control_volume_offsets(u) for u in v)
  aligned_v = tuple(
      tuple(interpolation.linear(v[i], offsets[i][j])
            for j in range(grid.ndim))
      for i in range(grid.ndim))
  return aligned_v


def _velocities_to_flux(
    aligned_v: Tuple[GridVariableVector]) -> Tuple[GridVariableVector]:
  """Computes the fluxes across the control volume faces for a velocity field.

  This is the flux associated with the nonlinear term `vv` for velocity `v`.
  The boundary condition on the flux is inherited from `v`.

  Args:
    aligned_v: a d-tuple of d-tuples of `GridVariable`s such that the entry
    `aligned_v[i][j]` is the component `v[i]` after interpolation to
    the appropriate face of the control volume centered around `v[j]`. This is
    the output of `_align_velocities`.

  Returns:
    A tuple of tuples `flux` of `GridVariable`s with the same structure as
    `aligned_v`. The entry `flux[i][j]` is `aligned_v[i][j] * aligned_v[j][i]`.
  """
  ndim = len(aligned_v)
  flux = [tuple() for _ in range(ndim)]
  for i in range(ndim):
    for j in range(ndim):
      if i <= j:
        bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar(
            aligned_v[j][i], aligned_v[i][j], j)
        flux[i] += (bc.impose_bc(aligned_v[i][j].array *
                                 aligned_v[j][i].array),)
      else:
        flux[i] += (flux[j][i],)
  return tuple(flux)


def convect_linear(v: GridVariableVector) -> GridArrayVector:
  """Computes convection/self-advection of the velocity field `v`.

  This function is conceptually equivalent to

  ```
  def convect_linear(v, grid):
    return tuple(advect_linear(u, v, grid) for u in v)
  ```

  However, there are several optimizations to avoid duplicate operations.

  Args:
    v: a velocity field.

  Returns:
    A tuple containing the time derivative of each component of `v` due to
    convection.
  """
  # TODO(jamieas): consider a more efficient vectorization of this function.
  # TODO(jamieas): incorporate variable density.
  aligned_v = _align_velocities(v)
  fluxes = _velocities_to_flux(aligned_v)
  return tuple(-fd.divergence(flux) for flux in fluxes)


def advect_van_leer(
    c: GridVariable,
    v: GridVariableVector,
    dt: float,
    mode: str = boundaries.Padding.MIRROR,
) -> GridArray:
  """Computes advection of a scalar quantity `c` by the velocity field `v`.

  Implements Van-Leer flux limiting scheme that uses second order accurate
  approximation of fluxes for smooth regions of the solution. This scheme is
  total variation diminishing (TVD). For regions with high gradients flux
  limitor transformes the scheme into a first order method. For [1] for
  reference. This function follows the following procedure:

    1. Shifts c to offset < 1 if necessary.
    2. Scalar c now has a well defined right-hand (upwind) value.
    3. Computes upwind flux for each direction.
    4. Computes van leer flux limiter:
      a. Use the shifted c to interpolate each component of `v` to the
        right-hand (upwind) face of the control volume centered on  `c`.
      b. Compute the ratio of successive gradients:
        In nonperiodic case, the value outside the boundary is not defined.
        Mode is used to interpolate past the boundary.
      c. Compute flux limiter function.
      d. Computes higher order flux correction.
    5. Combines fluxes and assigns flux boundary condition.
    6. Computes the negative divergence of fluxes.
    7. Shifts the computed values back to original offset of c.

  Args:
    c: the quantity to be transported.
    v: a velocity field. Should be defined on the same Grid as c.
    dt: time step for which this scheme is TVD and second order accurate
      in time.
    mode: For non-periodic BC, specifies extrapolation of values beyond the
      boundary, which is used by nonlinear interpolation.

  Returns:
    The time derivative of `c` due to advection by `v`.

  #### References

  [1]:  MIT 18.336 spring 2009 Finite Volume Methods Lecture 19.
        go/mit-18.336-finite_volume_methods-19
  [2]:
    www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf

  """
  # TODO(dkochkov) reimplement this using apply_limiter method.
  c_left_var = c
  # if the offset is 1., shift by 1 to offset 0.
  # otherwise c_right is not defined.
  for ax in range(c.grid.ndim):
    # int(c.offset[ax] % 1 - c.offset[ax]) = -1 if c.offset[ax] is 1 else
    # int(c.offset[ax] % 1 - c.offset[ax]) = 0.
    # i.e. this shifts the 1 aligned data to 0 offset, the rest is unchanged.
    c_left_var = c.bc.impose_bc(
        c_left_var.shift(int(c.offset[ax] % 1 - c.offset[ax]), axis=ax))
  offsets = grids.control_volume_offsets(c_left_var)
  # if c offset is 0, aligned_v is at 0.5.
  # if c offset is at .5, aligned_v is at 1.
  aligned_v = tuple(interpolation.linear(u, offset)
                    for u, offset in zip(v, offsets))
  flux = []
  # Assign flux boundary condition
  flux_bc = [
      boundaries.get_advection_flux_bc_from_velocity_and_scalar(u, c, direction)
      for direction, u in enumerate(v)
  ]
  # first, compute upwind flux.
  for axis, u in enumerate(aligned_v):
    c_center = c_left_var.data
    # by shifting c_left + 1, c_right is well-defined.
    c_right = c_left_var.shift(+1, axis=axis).data
    upwind_flux = grids.applied(jnp.where)(
        u.array > 0, u.array * c_center, u.array * c_right)
    flux.append(upwind_flux)
  # next, compute van_leer correction.
  for axis, (u, h) in enumerate(zip(aligned_v, c.grid.step)):
    u = u.bc.shift(u.array, int(u.offset[axis] % 1 - u.offset[axis]), axis=axis)
    # c is put to offset .5 or 1.
    c_center_arr = c.shift(int(1 - c.offset[ax]), axis=ax)
    # if c offset is 1, u offset is .5.
    # if c offset is .5, u offset is 0.
    # u_i is always on the left of c_center_var_i
    c_center = c_center_arr.data
    # shift -1 are well defined now
    # shift +1 is not well defined for c offset 1 because then c(wall + 1) is
    # not defined.
    # However, the flux that uses c(wall + 1) offset gets overridden anyways
    # when flux boundary condition is overridden.
    # Thus, any mode can be used here.
    c_right = c.bc.shift(c_center_arr, +1, axis=axis, mode=mode).data
    c_left = c.bc.shift(c_center_arr, -1, axis=axis).data
    # shift -2 is tricky:
    # It is well defined if c is periodic.
    # Else, c(-1) or c(-1.5) are not defined.
    # Then, mode is used to interpolate the values.
    c_left_left = c.bc.shift(
        c_center_arr, -2, axis, mode=mode).data

    numerator_positive = c_left - c_left_left
    numerator_negative = c_right - c_center
    numerator = grids.applied(jnp.where)(u > 0, numerator_positive,
                                         numerator_negative)
    denominator = grids.GridArray(c_center - c_left, u.offset, u.grid)
    # We want to calculate denominator / (abs(denominator) + abs(numerator))
    # To make it differentiable, it needs to be done in stages.

    # ensures that there is no division by 0
    phi_van_leer_denominator_avoid_nans = grids.applied(jnp.where)(
        abs(denominator) > 0, (abs(denominator) + abs(numerator)), 1.)

    phi_van_leer_denominator_inv = denominator / phi_van_leer_denominator_avoid_nans

    phi_van_leer = numerator * (grids.applied(jnp.sign)(denominator) +
                                grids.applied(jnp.sign)
                                (numerator)) * phi_van_leer_denominator_inv
    abs_velocity = abs(u)
    courant_numbers = (dt / h) * abs_velocity
    pre_factor = 0.5 * (1 - courant_numbers) * abs_velocity
    flux_correction = pre_factor * phi_van_leer
    # Shift back onto original offset.
    flux_correction = flux_bc[axis].shift(
        flux_correction, int(offsets[axis][axis] - u.offset[axis]), axis=axis)
    flux[axis] += flux_correction
  flux = tuple(flux_bc[axis].impose_bc(f) for axis, f in enumerate(flux))
  advection = -fd.divergence(flux)
  # shift the variable back onto the original offset
  for ax in range(c.grid.ndim):
    advection = c.bc.shift(
        advection, -int(c.offset[ax] % 1 - c.offset[ax]), axis=ax)
  return advection


def advect_step_semilagrangian(
    c: GridVariable,
    v: GridVariableVector,
    dt: float
) -> GridVariable:
  """Semi-Lagrangian advection of a scalar quantity.

  Note that unlike the other advection functions, this function returns values
  at the next time-step, not the time derivative.

  Args:
    c: the quantity to be transported.
    v: a velocity field. Should be defined on the same Grid as c.
    dt: desired time-step.

  Returns:
    Advected quantity at the next time-step -- *not* the time derivative.
  """
  # Reference: "Learning to control PDEs with Differentiable Physics"
  # https://openreview.net/pdf?id=HyeSin4FPB (see Appendix A)
  grid = grids.consistent_grid(c, *v)

  # TODO(shoyer) Enable lower domains != 0 for this function.
  # Hint: indices = [
  #     -o + (x - l) * n / (u - l)
  #     for (l, u), o, x, n in zip(grid.domain, c.offset, coords, grid.shape)
  # ]
  if not all(d[0] == 0 for d in grid.domain):
    raise ValueError(
        f'Grid domains currently must start at zero. Found {grid.domain}')
  coords = [x - dt * interpolation.linear(u, c.offset).data
            for x, u in zip(grid.mesh(c.offset), v)]
  indices = [x / s - o for s, o, x in zip(grid.step, c.offset, coords)]
  if not boundaries.has_all_periodic_boundary_conditions(c):
    raise NotImplementedError('non-periodic BCs not yet supported')
  c_advected = grids.applied(jax.scipy.ndimage.map_coordinates)(
      c.array, indices, order=1, mode='wrap')
  return GridVariable(c_advected, c.bc)


# TODO(dkochkov) Implement advect_with_flux_limiter method.
# TODO(dkochkov) Consider moving `advect_van_leer` to test based on performance.
def advect_van_leer_using_limiters(
    c: GridVariable,
    v: GridVariableVector,
    dt: float
) -> GridArray:
  """Implements Van-Leer advection by applying TVD limiter to Lax-Wendroff."""
  c_interpolation_fn = interpolation.apply_tvd_limiter(
      interpolation.lax_wendroff, limiter=interpolation.van_leer_limiter)
  return advect_general(c, v, interpolation.linear, c_interpolation_fn, dt)


def stable_time_step(max_velocity: float,
                     max_courant_number: float,
                     grid: grids.Grid) -> float:
  """Calculate a stable time step size for explicit advection.

  The calculation is based on the CFL condition for advection.

  Args:
    max_velocity: maximum velocity.
    max_courant_number: the Courant number used to choose the time step. Smaller
      numbers will lead to more stable simulations. Typically this should be in
      the range [0.5, 1).
    grid: a `Grid` object.

  Returns:
    The prescribed time interval.
  """
  dx = min(grid.step)
  return max_courant_number * dx / max_velocity