diffusion.py 7.37 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO(pnorgaard) Implement bicgstab for non-symmetric operators

"""Module for functionality related to diffusion."""
from typing import Optional, Tuple

import jax.numpy as jnp
import jax.scipy.sparse.linalg
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

# Short local aliases for types defined in `jax_cfd.base.grids`, so that the
# annotations in this module stay compact.
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector


def diffuse(c: GridVariable, nu: float) -> GridArray:
  """Computes the diffusive tendency of `c`.

  Args:
    c: the quantity being diffused.
    nu: diffusion coefficient multiplying the Laplacian.

  Returns:
    `nu` times the finite-difference Laplacian of `c`.
  """
  laplacian_of_c = fd.laplacian(c)
  return nu * laplacian_of_c


def stable_time_step(viscosity: float, grid: grids.Grid) -> float:
  """Returns a time step size that keeps explicit diffusion stable.

  The bound is `min(grid.step)**2 / (viscosity * 2**grid.ndim)`; the original
  authors attribute it to analysis of central-time-central-space (CTCS)
  schemes.

  Args:
    viscosity: kinematic viscosity.
    grid: a `Grid` object; only its `step` and `ndim` attributes are read.

  Returns:
    The prescribed time interval, or infinity when `viscosity` is zero (no
    diffusion means no diffusive restriction on the step).
  """
  if viscosity == 0:
    return float('inf')
  # The tightest restriction comes from the finest grid spacing.
  smallest_spacing = min(grid.step)
  denominator = viscosity * 2 ** grid.ndim
  return smallest_spacing ** 2 / denominator


def _subtract_linear_part_dirichlet(
    c_data: Array,
    grid: grids.Grid,
    axis: int,
    offset: Tuple[float, float],
    bc_values: Tuple[float, float],
) -> Array:
  """Transforms c_data such that c_data satisfies dirichlet boundary.

  The function subtracts a linear function from c_data s.t. the returned
  array has homogeneous dirichlet boundaries. Note that this assumes c_data has
  constant dirichlet boundary values.

  The linear function equals `bc_values[0]` at the start of the domain along
  `axis`, `bc_values[1]` at its end, and is constant along every other axis.

  Args:
    c_data: right-hand-side of diffusion equation.
    grid: grid object
    axis: axis along which to impose boundary transformation
    offset: offset of the right-hand-side
    bc_values: boundary values along axis

  Returns:
    transformed right-hand-side
  """

  def _update_rhs_along_axis(arr_1d, linear_part):
    return arr_1d - linear_part

  lower_value, upper_value = bc_values
  # Extract the 1D coordinates along `axis`: take every index along `axis` and
  # pin all other indices to 0. Indexing only the first dimension (`[0]`), as
  # was done previously, yields these coordinates solely when `axis` is the
  # second axis of a 2D grid; for that case this expression is identical.
  # NOTE(review): like the previous lookup, this assumes `grid.mesh` uses
  # matrix ('ij') index ordering -- confirm against `Grid.mesh`.
  axis_index = axis % grid.ndim  # also accept negative axes
  line_index = tuple(
      slice(None) if dim == axis_index else 0 for dim in range(grid.ndim))
  y = grid.mesh(offset)[axis][line_index]
  # Wrap the coordinates in a 1D GridArray and trim them with 1D dirichlet
  # boundary conditions, so boundary points are dropped the same way they are
  # presumably dropped from c_data by the caller.
  one_d_grid = grids.Grid((grid.shape[axis],), domain=(grid.domain[axis],))
  y_boundary = boundaries.dirichlet_boundary_conditions(ndim=1)
  y = y_boundary.trim_boundary(grids.GridArray(y, (offset[axis],),
                                               one_d_grid)).data
  domain_length = (grid.domain[axis][1] - grid.domain[axis][0])
  domain_start = grid.domain[axis][0]
  # Straight line through (domain start, lower_value), (domain end, upper_value).
  linear_part = lower_value + (upper_value - lower_value) * (
      y - domain_start) / domain_length
  # Subtract the same 1D profile from every line of c_data running along axis.
  c_data = jnp.apply_along_axis(
      _update_rhs_along_axis, axis, c_data, linear_part)
  return c_data


def _rhs_transform(
    u: grids.GridArray,
    bc: boundaries.BoundaryConditions,
) -> Array:
  """Prepares the right-hand side of the diffusion equation.

  For constant dirichlet boundary conditions the linear term matching the
  boundary values is subtracted, one boundary at a time. See
  diffusion.solve_fast_diag.

  Args:
    u: a GridArray that solves ∇²x = ∇²u for x.
    bc: specifies boundary of u.

  Returns:
    u' s.t. u = u' + w where u' has 0 dirichlet bc and w is linear.

  Raises:
    NotImplementedError: if `bc` is not a `ConstantBoundaryConditions`, or if
      an axis with a Neumann boundary has any nonzero boundary value.
  """
  if not isinstance(bc, boundaries.ConstantBoundaryConditions):
    raise NotImplementedError(
        f'transformation cannot be done for this {bc}.')
  transformed = u.data
  for axis in range(u.grid.ndim):
    for side in range(2):  # 0: lower boundary, 1: upper boundary
      side_type = bc.types[axis][side]
      if side_type == boundaries.BCType.DIRICHLET:
        # Remove this boundary's contribution only: the opposite boundary
        # value is taken as zero here and handled in its own iteration.
        one_sided_values = [0., 0.]
        one_sided_values[side] = bc.bc_values[axis][side]
        transformed = _subtract_linear_part_dirichlet(
            transformed, u.grid, axis, u.offset, one_sided_values)
      elif (side_type == boundaries.BCType.NEUMANN and
            any(bc.bc_values[axis])):
        raise NotImplementedError(
            'transformation is not implemented for inhomogeneous Neumann bc.')
  return transformed


def solve_cg(v: GridVariableVector,
             nu: float,
             dt: float,
             rtol: float = 1e-6,
             atol: float = 1e-6,
             maxiter: Optional[int] = None) -> GridVariableVector:
  """Implicit diffusion step solved with conjugate gradients.

  Each component `u` of `v` is replaced by the solution of
  (1 - ν Δt ∇²) u_{t+1} = u, using `u` itself as the initial guess.

  Args:
    v: components to diffuse; all must have periodic boundary conditions.
    nu: diffusion coefficient ν.
    dt: time step Δt.
    rtol: forwarded to `jax.scipy.sparse.linalg.cg` as `tol`.
    atol: forwarded to `jax.scipy.sparse.linalg.cg` as `atol`.
    maxiter: forwarded to `jax.scipy.sparse.linalg.cg` as `maxiter`.

  Returns:
    A tuple with the diffused components, each carrying the boundary
    conditions of the corresponding input.

  Raises:
    ValueError: if any component of `v` is not periodic.
  """
  if not boundaries.has_all_periodic_boundary_conditions(*v):
    raise ValueError('solve_cg() expects periodic BC')

  def _implicit_step(component: GridVariable) -> GridVariable:
    """Solves (1 - ν Δt ∇²) u_{t+1} = u_{tilda} for one component."""
    bc = component.bc

    def helmholtz(candidate: GridArray) -> GridArray:
      """Applies (1 - ν Δt ∇²) to `candidate`."""
      # The Laplacian needs boundary conditions; borrow the component's.
      wrapped = grids.GridVariable(candidate, bc)
      return wrapped.array - dt * nu * fd.laplacian(wrapped)

    rhs = component.array
    # The right-hand side doubles as the initial guess.
    solution, _ = jax.scipy.sparse.linalg.cg(
        helmholtz, rhs, x0=rhs, tol=rtol, atol=atol, maxiter=maxiter)
    return grids.GridVariable(solution, bc)

  return tuple(_implicit_step(u) for u in v)


def solve_fast_diag(
    v: GridVariableVector,
    nu: float,
    dt: float,
    implementation: Optional[str] = None,
) -> GridVariableVector:
  """Solve for diffusion using the fast diagonalization approach.

  Args:
    v: components to diffuse.
    nu: diffusion coefficient ν.
    dt: time step Δt.
    implementation: forwarded to `fast_diagonalization.transform`. It is
      overridden with 'matmul' unless every component of `v` is periodic.

  Returns:
    A tuple with one diffused component per entry of `v`, padded back to the
    original offset with its boundary conditions imposed.
  """
  # We reuse eigenvectors from the Laplacian and transform the eigenvalues
  # because this is better conditioned than directly diagonalizing 1 - ν Δt ∇²
  # when ν Δt is small.
  def func(x):
    dt_nu_x = (dt * nu) * x
    return dt_nu_x / (1 - dt_nu_x)

  # Compute (1 - ν Δt ∇²)⁻¹ u as u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u, for less
  # error when ν Δt is small.
  # If dirichlet bc are supplied: only works for dirichlet bc that are linear
  # functions on the boundary. Then u = u' + w where u' has 0 dirichlet bc and
  # w is linear. Then u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u = u +
  # (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function _rhs_transform subtracts
  # the linear part s.t. fast_diagonalization solves
  # u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u'.
  circulant = boundaries.has_all_periodic_boundary_conditions(*v)
  if not circulant:
    # only matmul implementation supports non-circulant matrices
    implementation = 'matmul'
  v_diffused = []
  for u in v:
    laplacians = array_utils.laplacian_matrix_w_boundaries(
        u.grid, u.offset, u.bc)
    # Build the operator in the dtype of the component it is applied to
    # (rather than always the first component's), so components of differing
    # precision are each handled consistently.
    op = fast_diagonalization.transform(
        func,
        laplacians,
        u.dtype,
        hermitian=True,
        circulant=circulant,
        implementation=implementation)
    u_interior = u.bc.trim_boundary(u.array)
    u_interior_transformed = _rhs_transform(u_interior, u.bc)
    u_dt_diffused = grids.GridArray(
        op(u_interior_transformed), u_interior.offset, u_interior.grid)
    u_diffused = u_interior + u_dt_diffused
    u_diffused = u.bc.pad_and_impose_bc(u_diffused, offset_to_pad_to=u.offset)
    v_diffused.append(u_diffused)
  return tuple(v_diffused)