finite_differences.py 7.42 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for approximating derivatives.

Finite difference methods operate on GridVariable and return GridArray.
Evaluating finite differences requires boundary conditions, which are defined
for a GridVariable. But the operation of taking a derivative makes the boundary
condition undefined, which is why the return type is GridArray.

For example, if the variable c has the boundary condition c_b = 0, and we take
the derivate dc/dx, then it is unclear what the boundary condition on dc/dx
should be. So programmatically, after taking finite differences and doing
operations, the user has to explicitly assign boundary conditions to the result.

Example:
  c = GridVariable(c_array, c_boundary_condition)
  dcdx = finite_differences.forward_difference(c)  # returns GridArray
  c_new = c + dt * (-velocity * dcdx)  # operations on GridArrays
  c_new = GridVariable(c_new, c_boundary_condition)  # assocaite BCs
"""

import typing
from typing import Optional, Sequence, Tuple
from jax_cfd.base import grids
from jax_cfd.base import interpolation
import numpy as np

GridArray = grids.GridArray
GridVariable = grids.GridVariable
GridArrayTensor = grids.GridArrayTensor


def stencil_sum(*arrays: GridArray) -> GridArray:
  """Sum arrays across a stencil, with an averaged offset."""
  # pylint: disable=line-too-long
  offset = grids.averaged_offset(*arrays)
  # pytype appears to have a bad type signature for sum():
  # Built-in function sum was called with the wrong arguments [wrong-arg-types]
  #          Expected: (iterable: Iterable[complex])
  #   Actually passed: (iterable: Generator[Union[jax.interpreters.xla.DeviceArray, numpy.ndarray], Any, None])
  result = sum(array.data for array in arrays)  # type: ignore
  grid = grids.consistent_grid(*arrays)
  return grids.GridArray(result, offset, grid)


# incompatible with typing.overload
# pylint: disable=pointless-statement
# pylint: disable=function-redefined
# pylint: disable=unused-argument


@typing.overload
def central_difference(u: GridVariable, axis: int) -> GridArray:
  ...


@typing.overload
def central_difference(
    u: GridVariable, axis: Optional[Tuple[int, ...]]) -> Tuple[GridArray, ...]:
  ...


def central_difference(u, axis=None):
  """Approximates grads with central differences."""
  if axis is None:
    axis = range(u.grid.ndim)
  if not isinstance(axis, int):
    return tuple(central_difference(u, a) for a in axis)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
  diff = stencil_sum(u.shift(+1, axis), -u.shift(-1, axis))
  return diff / (2 * u.grid.step[axis])


@typing.overload
def backward_difference(u: GridVariable, axis: int) -> GridArray:
  ...


@typing.overload
def backward_difference(
    u: GridVariable, axis: Optional[Tuple[int, ...]]) -> Tuple[GridArray, ...]:
  ...


def backward_difference(u, axis=None):
  """Approximates grads with finite differences in the backward direction."""
  if axis is None:
    axis = range(u.grid.ndim)
  if not isinstance(axis, int):
    return tuple(backward_difference(u, a) for a in axis)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
  diff = stencil_sum(u.array, -u.shift(-1, axis))
  return diff / u.grid.step[axis]


@typing.overload
def forward_difference(u: GridVariable, axis: int) -> GridArray:
  ...


@typing.overload
def forward_difference(
    u: GridVariable,
    axis: Optional[Tuple[int, ...]] = ...) -> Tuple[GridArray, ...]:
  ...


def forward_difference(u, axis=None):
  """Approximates grads with finite differences in the forward direction."""
  if axis is None:
    axis = range(u.grid.ndim)
  if not isinstance(axis, int):
    return tuple(forward_difference(u, a) for a in axis)  # pytype: disable=wrong-arg-types  # always-use-return-annotations
  diff = stencil_sum(u.shift(+1, axis), -u.array)
  return diff / u.grid.step[axis]


def laplacian(u: GridVariable) -> GridArray:
  """Approximates the Laplacian of `u`."""
  scales = np.square(1 / np.array(u.grid.step, dtype=u.dtype))
  result = -2 * u.array * np.sum(scales)
  for axis in range(u.grid.ndim):
    result += stencil_sum(u.shift(-1, axis), u.shift(+1, axis)) * scales[axis]
  return result


def divergence(v: Sequence[GridVariable]) -> GridArray:
  """Approximates the divergence of `v` using backward differences."""
  grid = grids.consistent_grid(*v)
  if len(v) != grid.ndim:
    raise ValueError('The length of `v` must be equal to `grid.ndim`.'
                     f'Expected length {grid.ndim}; got {len(v)}.')
  differences = [backward_difference(u, axis) for axis, u in enumerate(v)]
  return sum(differences)


def centered_divergence(v: Sequence[GridVariable]) -> GridArray:
  """Approximates the divergence of `v` using centered differences."""
  grid = grids.consistent_grid(*v)
  if len(v) != grid.ndim:
    raise ValueError('The length of `v` must be equal to `grid.ndim`.'
                     f'Expected length {grid.ndim}; got {len(v)}.')
  differences = [central_difference(u, axis) for axis, u in enumerate(v)]
  return sum(differences)


@typing.overload
def gradient_tensor(v: GridVariable) -> GridArrayTensor:
  ...


@typing.overload
def gradient_tensor(v: Sequence[GridVariable]) -> GridArrayTensor:
  ...


def gradient_tensor(v):
  """Approximates the cell-centered gradient of `v`."""
  if not isinstance(v, GridVariable):
    return GridArrayTensor(np.stack([gradient_tensor(u) for u in v], axis=-1))  # pytype: disable=wrong-arg-types  # always-use-return-annotations
  grad = []
  for axis in range(v.grid.ndim):
    offset = v.offset[axis]
    if offset == 0:
      derivative = forward_difference(v, axis)
    elif offset == 1:
      derivative = backward_difference(v, axis)
    elif offset == 0.5:
      v_centered = interpolation.linear(v, v.grid.cell_center)
      derivative = central_difference(v_centered, axis)
    else:
      raise ValueError(f'expected offset values in {{0, 0.5, 1}}, got {offset}')
    grad.append(derivative)
  return GridArrayTensor(grad)


def curl_2d(v: Sequence[GridVariable]) -> GridArray:
  """Approximates the curl of `v` in 2D using forward differences."""
  if len(v) != 2:
    raise ValueError(f'Length of `v` is not 2: {len(v)}')
  grid = grids.consistent_grid(*v)
  if grid.ndim != 2:
    raise ValueError(f'Grid dimensionality is not 2: {grid.ndim}')
  return forward_difference(v[1], axis=0) - forward_difference(v[0], axis=1)


def curl_3d(
    v: Sequence[GridVariable]) -> Tuple[GridArray, GridArray, GridArray]:
  """Approximates the curl of `v` in 2D using forward differences."""
  if len(v) != 3:
    raise ValueError(f'Length of `v` is not 3: {len(v)}')
  grid = grids.consistent_grid(*v)
  if grid.ndim != 3:
    raise ValueError(f'Grid dimensionality is not 3: {grid.ndim}')
  curl_x = (forward_difference(v[2], axis=1) - forward_difference(v[1], axis=2))
  curl_y = (forward_difference(v[0], axis=2) - forward_difference(v[2], axis=0))
  curl_z = (forward_difference(v[1], axis=0) - forward_difference(v[0], axis=1))
  return (curl_x, curl_y, curl_z)