pressures.py 905 Bytes
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""Models for pressure solvers.

All modules are functions that return a `pressure_solve` function with the same
signature as the baseline methods, e.g. `pressure.solve_fast_diag`.
"""
import functools
from typing import Callable, Optional

import gin

from jax_cfd.base import grids
from jax_cfd.base import pressure


# Short local names for the grid types used in the signatures below.
GridArray = grids.GridArray
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector

# Signature shared by every pressure solver returned from this module: it takes
# a GridVariableVector and an optional GridVariable, and returns a GridArray.
# NOTE(review): the optional argument is presumably an initial pressure guess;
# confirm against `pressure.solve_fast_diag`.
PressureSolveFn = Callable[
    [GridVariableVector, Optional[GridVariable]],
    GridArray,
]

# A "pressure module" is any factory (arbitrary arguments) that builds a
# PressureSolveFn.
PressureModule = Callable[..., PressureSolveFn]


@gin.register
def fast_diagonalization(grid, dt, physics_specs):
  """Pressure module that selects the fast-diagonalization solver.

  Args:
    grid: ignored; accepted so all pressure modules share one factory signature.
    dt: ignored; see `grid`.
    physics_specs: ignored; see `grid`.

  Returns:
    `pressure.solve_fast_diag`, returned as-is with no arguments bound.
  """
  solver = pressure.solve_fast_diag
  # None of the factory inputs influence which solver is chosen.
  del grid, dt, physics_specs
  return solver


@gin.register
def conjugate_gradient(grid, dt, physics_specs, atol=1e-5, maxiter=32):
  """Pressure module that selects the conjugate-gradient solver.

  Args:
    grid: ignored; accepted so all pressure modules share one factory signature.
    dt: ignored; see `grid`.
    physics_specs: ignored; see `grid`.
    atol: forwarded to `pressure.solve_cg` as its `atol` keyword (presumably
      the absolute convergence tolerance).
    maxiter: forwarded to `pressure.solve_cg` as its `maxiter` keyword
      (presumably the iteration cap).

  Returns:
    `pressure.solve_cg` wrapped in `functools.partial` with `atol` and
    `maxiter` pre-bound as keyword arguments.
  """
  # The grid/time-step/physics inputs are part of the shared interface only.
  del grid, dt, physics_specs
  solver_options = dict(atol=atol, maxiter=maxiter)
  return functools.partial(pressure.solve_cg, **solver_options)