diffusions.py 2.81 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Models for diffusion components.

All modules are functions that return either a `DiffuseFn` or a
`DiffusionSolveFn`. The two kinds of diffusion modules should be used with the
corresponding explicit and implicit Navier-Stokes solvers, respectively.

An example explicit diffusion module:

```python
def diffusion_module(grid, dt, physics_specs, **kwargs):
  pre_compute_values = f(grid, dt, physics_specs)
  def diffuse(c: grids.GridVariable, nu: float) -> grids.GridArray:
    # compute time derivative due to diffusion.
    return dc_dt

  return diffuse
```
"""
import functools
from typing import Callable, Optional
import gin
import haiku as hk
from jax_cfd.base import diffusion
from jax_cfd.base import grids
from jax_cfd.base import subgrid_models

from jax_cfd.ml import viscosities


# Short local names for the grid types from `jax_cfd.base.grids`.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector

# Explicit diffusion: maps (variable, float) to a GridArray (per the module
# docstring, the time derivative due to diffusion).
DiffuseFn = Callable[[GridVariable, float], GridArray]

# Implicit diffusion solve: maps a velocity-like vector plus two floats
# (presumably viscosity and time step -- NOTE(review): confirm against
# `diffusion.solve_*`) to an updated GridVariableVector.
DiffusionSolveFn = Callable[
    [GridVariableVector, float, float], GridVariableVector]

# Factories: any callable that builds one of the function types above.
DiffuseModule = Callable[..., DiffuseFn]
DiffusionSolveModule = Callable[..., DiffusionSolveFn]

ViscosityModule = viscosities.ViscosityModule


# TODO(shoyer): stop deleting unrecognized **kwargs. This is really error-prone!


@gin.register(denylist=("grid", "dt", "physics_specs"))
def diffuse(grid, dt, physics_specs) -> DiffuseFn:
  """Returns `diffusion.diffuse` as the explicit diffusion function.

  Args:
    grid: ignored; accepted so all modules share the same signature.
    dt: ignored.
    physics_specs: ignored.

  Returns:
    `jax_cfd.base.diffusion.diffuse`, unmodified.
  """
  del grid, dt, physics_specs  # Only present for the common module interface.
  explicit_diffuse = diffusion.diffuse
  return explicit_diffuse


@gin.register(denylist=("grid", "dt", "physics_specs"))
def solve_fast_diag(
    grid,
    dt,
    physics_specs,
    implementation=None
) -> DiffusionSolveFn:
  """Returns `diffusion.solve_fast_diag` with `implementation` pre-bound.

  Args:
    grid: ignored; accepted so all modules share the same signature.
    dt: ignored.
    physics_specs: ignored.
    implementation: forwarded verbatim as the `implementation` keyword of
      `diffusion.solve_fast_diag`.

  Returns:
    A `DiffusionSolveFn` wrapping `diffusion.solve_fast_diag`.
  """
  del grid, dt, physics_specs  # Only present for the common module interface.
  solver_options = {"implementation": implementation}
  return functools.partial(diffusion.solve_fast_diag, **solver_options)


@gin.register(denylist=("grid", "dt", "physics_specs"))
def solve_cg(
    grid,
    dt,
    physics_specs,
    atol: float = 1e-5,
    rtol: float = 1e-5,
    maxiter: Optional[int] = 64,
) -> DiffusionSolveFn:
  """Returns conjugate gradient solve method.

  Args:
    grid: ignored; accepted so all modules share the same signature.
    dt: ignored.
    physics_specs: ignored.
    atol: absolute tolerance bound into `diffusion.solve_cg`.
    rtol: relative tolerance bound into `diffusion.solve_cg`.
    maxiter: iteration cap bound into `diffusion.solve_cg`.

  Returns:
    A `DiffusionSolveFn` wrapping `diffusion.solve_cg` with the tolerances
    and iteration limit fixed.
  """
  del grid, dt, physics_specs  # Only present for the common module interface.
  cg_options = dict(atol=atol, rtol=rtol, maxiter=maxiter)
  return functools.partial(diffusion.solve_cg, **cg_options)


@gin.register(denylist=("grid", "dt", "physics_specs"))
def implicit_evm_solve_with_diffusion(
    grid,
    dt,
    physics_specs,
    viscosity_module: ViscosityModule = viscosities.eddy_viscosity_model,
    atol: float = 1e-5,
    maxiter: Optional[int] = 64,
) -> DiffusionSolveFn:
  """Returns solve_diffusion method that also includes a viscosity model.

  Args:
    grid: passed to `viscosity_module` to build the eddy-viscosity model.
    dt: passed to `viscosity_module`.
    physics_specs: passed to `viscosity_module`.
    viscosity_module: factory called as `viscosity_module(grid, dt,
      physics_specs)` to produce the configured eddy-viscosity model.
    atol: stored under the "atol" key of the `cg_kwargs` handed to
      `subgrid_models.implicit_evm_solve_with_diffusion`.
    maxiter: stored under the "maxiter" key of the same `cg_kwargs`.

  Returns:
    `subgrid_models.implicit_evm_solve_with_diffusion` with the model and CG
    options bound, wrapped via `hk.to_module` and instantiated with the name
    "diffusion_solve". NOTE(review): `hk.to_module` instantiation presumably
    must happen inside a Haiku transform -- confirm at call sites.
  """
  # Build the viscosity model first; the original ordering is kept in case
  # the factory itself creates Haiku modules.
  configured_model = viscosity_module(grid, dt, physics_specs)
  cg_options = {"atol": atol, "maxiter": maxiter}
  bound_solver = functools.partial(
      subgrid_models.implicit_evm_solve_with_diffusion,
      configured_evm_model=configured_model,
      cg_kwargs=cg_options,
  )
  solver_module_cls = hk.to_module(bound_solver)
  return solver_module_cls(name="diffusion_solve")