pressure.py 6.36 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for computing and applying pressure."""

from typing import Callable, Optional

import jax.numpy as jnp
import jax.scipy.sparse.linalg

from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import fast_diagonalization
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

# Short module-local aliases for the `jax_cfd.base.grids` types used in the
# signatures below.
Array = grids.Array
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions

# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic


# TODO(pnorgaard) Implement bicgstab for non-symmetric operators


def _rhs_transform(
    u: grids.GridArray,
    bc: boundaries.BoundaryConditions,
) -> Array:
  """Transform the RHS of pressure projection equation for stability.

  The discrete Poisson problem ∇²x = u has a kernel of constant functions when
  no boundary pins the value of x, i.e. when every side of every axis is
  Neumann or periodic. In that case the problem is only consistent if u has
  zero mean, so the mean is removed. The mean is removed only when at least
  one axis is Neumann on both sides; purely periodic problems are returned
  unchanged. When any side is Dirichlet (or any other non-Neumann,
  non-periodic type), the Laplacian is invertible and u is left untouched,
  because subtracting its mean would change the solution.

  Args:
    u: a GridArray that solves ∇²x = u.
    bc: specifies boundary of x.

  Returns:
    u' s.t. u = u' + kernel of the laplacian: the data of `u` with its mean
    subtracted when the conditions above hold, otherwise the data of `u`.
  """
  u_data = u.data
  neumann = boundaries.BCType.NEUMANN
  periodic = boundaries.BCType.PERIODIC
  axis_types = [bc.types[axis] for axis in range(u.grid.ndim)]

  has_neumann_axis = any(
      sides[0] == neumann and sides[1] == neumann for sides in axis_types)
  # Any side that is neither Neumann nor periodic (e.g. Dirichlet) pins the
  # value of x and removes the constant kernel.
  has_constant_kernel = all(
      side in (neumann, periodic) for sides in axis_types for side in sides)

  if has_neumann_axis and has_constant_kernel:
    # Project the RHS onto the range of the singular operator. We substract
    # the mean once to ensure consistency.
    u_data = u_data - jnp.mean(u_data)
  return u_data


def solve_cg(
    v: GridVariableVector,
    q0: GridVariable,
    pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    maxiter: Optional[int] = None) -> GridArray:
  """Solves the pressure Poisson equation with conjugate gradients.

  Finds a pressure correction `q` with ∇²q = ∇·v, so that the corrected
  velocity `v - grad(q)` is divergence free. The pressure estimate itself is
  `p = q * density / dt`.

  Args:
    v: the velocity field; its divergence is the right-hand side.
    q0: initial guess for the correction (a common choice is the correction
      from the previous time step). Only `q0.array` is used, as the starting
      iterate; its boundary conditions are not consulted.
    pressure_bc: boundary conditions imposed on `q` every time the Laplacian
      is applied. If None, they are inferred from the velocity.
    rtol: relative tolerance, forwarded to the CG solver as `tol`.
    atol: absolute tolerance, forwarded to the CG solver.
    maxiter: optional cap on the number of CG iterations.

  Returns:
    A pressure correction `q` (GridArray) such that `div(v - grad(q))` is
    zero to within the requested tolerances.
  """
  # TODO(jamieas): add functionality for non-uniform density.
  bc = (boundaries.get_pressure_bc_from_velocity(v)
        if pressure_bc is None else pressure_bc)
  divergence = fd.divergence(v)

  def apply_laplacian(candidate: GridArray) -> GridArray:
    # Boundary values must be re-imposed before every Laplacian evaluation.
    return fd.laplacian(bc.impose_bc(candidate))

  solution, _ = jax.scipy.sparse.linalg.cg(
      apply_laplacian,
      divergence,
      x0=q0.array,
      tol=rtol,
      atol=atol,
      maxiter=maxiter,
  )
  return solution


def solve_fast_diag(
    v: GridVariableVector,
    q0: Optional[grids.GridArray] = None,
    pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
    implementation: Optional[str] = None,
) -> grids.GridArray:
  """Solves the pressure Poisson equation by fast diagonalization.

  For backward compatibility, `pressure_bc` may be omitted. It is then
  inferred from the velocity, so a velocity with all-periodic boundaries
  yields periodic pressure boundaries.

  Args:
    v: a tuple of velocity values for each direction.
    q0: ignored; accepted so the signature matches the other solvers.
    pressure_bc: the boundary condition to assign to pressure. If None, it is
      inferred from the velocity.
    implementation: which fast-diagonalization implementation to use. Forced
      to 'matmul' whenever the velocity boundaries are not all periodic.

  Returns:
    A solution to the PPE equation, located at the divergence's offset.
  """
  del q0  # A direct solve needs no initial guess.
  if pressure_bc is None:
    pressure_bc = boundaries.get_pressure_bc_from_velocity(v)

  # NOTE(review): circulance is judged from the velocity BCs, not from
  # `pressure_bc`; these agree whenever pressure_bc is inferred from `v`.
  circulant = bool(boundaries.has_all_periodic_boundary_conditions(*v))
  if not circulant:
    # Only the matmul implementation supports non-circulant matrices.
    implementation = 'matmul'

  divergence = fd.divergence(v)
  grid, offset = divergence.grid, divergence.offset
  operators = array_utils.laplacian_matrix_w_boundaries(
      grid, offset, pressure_bc)
  source = _rhs_transform(divergence, pressure_bc)
  apply_pinv = fast_diagonalization.pseudoinverse(
      operators,
      source.dtype,
      hermitian=True,
      circulant=circulant,
      implementation=implementation,
  )
  return grids.GridArray(apply_pinv(source), offset, grid)


def solve_fast_diag_channel_flow(
    v: GridVariableVector,
    q0: Optional[grids.GridArray] = None,
    pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None,
) -> grids.GridArray:
  """Pressure solve for channel flow via `solve_fast_diag` with matmul.

  A thin wrapper that always requests the 'matmul' fast-diagonalization
  implementation.

  Args:
    v: a tuple of velocity values for each direction.
    q0: the starting guess for the pressure; forwarded unchanged.
    pressure_bc: the boundary condition to assign to pressure. If None, it is
      inferred from the velocity.

  Returns:
    A solution to the PPE equation.
  """
  bc = pressure_bc
  if bc is None:
    bc = boundaries.get_pressure_bc_from_velocity(v)
  return solve_fast_diag(v, q0, pressure_bc=bc, implementation='matmul')


def projection(
    v: GridVariableVector,
    solve: Callable = solve_fast_diag,
) -> GridVariableVector:
  """Apply pressure projection to make a velocity field divergence free.

  The pressure boundary conditions are inferred from the velocity. A
  correction `q` is obtained from `solve(v, q0, pressure_bc)`, starting from a
  zero guess `q0` at cell centers. The result is `v - grad(q)`, with each
  component's own boundary conditions re-imposed.

  Args:
    v: velocity components, all defined on a consistent grid.
    solve: pressure solver called as `solve(v, q0, pressure_bc)`, e.g.
      `solve_fast_diag` or `solve_cg`.

  Returns:
    The projected velocity, one GridVariable per component of `v`.
  """
  grid = grids.consistent_grid(*v)
  pressure_bc = boundaries.get_pressure_bc_from_velocity(v)

  # Match the initial guess's dtype to the velocity. `jnp.zeros` defaults to
  # float64 when jax_enable_x64 is on. Iterative solvers that start from q0
  # (such as solve_cg) would then promote float32 velocities to float64,
  # changing the state dtype between steps.
  dtype = jnp.result_type(*(u.array.data.dtype for u in v))
  q0 = grids.GridArray(jnp.zeros(grid.shape, dtype), grid.cell_center, grid)
  q0 = pressure_bc.impose_bc(q0)

  q = solve(v, q0, pressure_bc)
  q = pressure_bc.impose_bc(q)
  q_grad = fd.forward_difference(q)
  v_projected = tuple(
      u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad))

  return v_projected