forcings.py 3.47 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Components that apply forcing. See jax_cfd.base.forcings for forcing API."""

from typing import Callable

from typing import Optional, Tuple
import gin
from jax import numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import equations
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.spectral import utils as spectral_utils

# Forcing-function type, re-exported from `jax_cfd.base.forcings`.
ForcingFn = forcings.ForcingFn
# A factory that builds a `ForcingFn`. Its arguments vary by factory; e.g.
# `vorticity_space_forcing` calls one as `forcing_module(grid, offsets=...)`.
ForcingModule = Callable[..., ForcingFn]


def sum_forcings(*forces: ForcingFn) -> ForcingFn:
  """Combine several forcing functions into a single one.

  Args:
    *forces: forcing functions; each is applied to the same input.

  Returns:
    A forcing function that evaluates every function in `forces` on its input
    (in the given order) and combines the results with
    `equations.sum_fields`.
  """
  def forcing(v):
    # Distinct loop name so it does not shadow the enclosing `forcing`.
    contributions = [force_fn(v) for force_fn in forces]
    return equations.sum_fields(*contributions)
  return forcing


@gin.register
def filtered_linear_forcing(grid: grids.Grid,
                            scale: float,
                            lower_wavenumber: float = 0,
                            upper_wavenumber: float = 4) -> ForcingFn:
  """Gin-configurable wrapper around `forcings.filtered_linear_forcing`.

  Args:
    grid: grid the forcing is defined on.
    scale: forwarded as the `coefficient` keyword of the underlying forcing.
    lower_wavenumber: first positional argument of the underlying forcing.
    upper_wavenumber: second positional argument of the underlying forcing.

  Returns:
    The forcing function built by `forcings.filtered_linear_forcing`.
  """
  wavenumber_band = (lower_wavenumber, upper_wavenumber)
  force_fn = forcings.filtered_linear_forcing(
      *wavenumber_band, coefficient=scale, grid=grid)
  return force_fn


@gin.register
def linear_forcing(grid: grids.Grid,
                   scale: float) -> ForcingFn:
  """Gin-configurable wrapper around `forcings.linear_forcing`.

  Args:
    grid: grid the forcing is defined on.
    scale: coefficient passed straight through to `forcings.linear_forcing`.

  Returns:
    The forcing function built by `forcings.linear_forcing`.
  """
  force_fn = forcings.linear_forcing(grid, scale)
  return force_fn


@gin.register
def spectral_kolmogorov_forcing(grid: grids.Grid,
                                scale: float = 1.0,
                                wavenumber: int = 4) -> ForcingFn:
  """Kolmogorov forcing with all components at zero offset.

  Thin wrapper around `forcings.kolmogorov_forcing` that fixes
  `swap_xy=False` and zero offsets for both velocity components. The
  defaults reproduce the previously hard-coded values (`scale=1.0`,
  wavenumber 4). They are now parameters, so they can be set via gin.

  Args:
    grid: grid the forcing is defined on.
    scale: forcing amplitude, passed as the second positional argument.
    wavenumber: passed as `k` to `forcings.kolmogorov_forcing`.

  Returns:
    The forcing function built by `forcings.kolmogorov_forcing`.
  """
  return forcings.kolmogorov_forcing(
      grid, scale, k=wavenumber, swap_xy=False,
      offsets=((0.0, 0.0), (0.0, 0.0)))


@gin.register
def vorticity_space_forcing(grid: grids.Grid, forcing_module: ForcingModule):
  """Adapts a velocity-space forcing so it can be applied to vorticity.

  The returned function:
    1. converts vorticity to velocity spectrally,
    2. evaluates the velocity forcing built by `forcing_module`, and
    3. returns the spectral curl `d(f_v)/dx - d(f_u)/dy` of that force,
       transformed back to real space.

  NOTE(review): the `2j * pi` factor assumes `grid.rfft_mesh()` returns
  wavenumbers in cycles per unit length; confirm.

  Args:
    grid: the (2D) grid the vorticity lives on.
    forcing_module: factory called as `forcing_module(grid, offsets=...)`
      that returns a forcing function of a velocity tuple.

  Returns:
    A function taking vorticity with a trailing channel axis of size 1. It
    returns the real-space vorticity forcing as produced by the inverse FFT;
    no channel axis is re-added.
  """
  forcing_fn = forcing_module(grid, offsets=((0.0, 0.0), (0.0, 0.0)))
  velocity_solve = spectral_utils.vorticity_to_velocity(grid)
  kx, ky = grid.rfft_mesh()
  fft = jnp.fft.rfft2
  bc = boundaries.periodic_boundary_conditions(grid.ndim)
  offset = (0.0, 0.0)  # TODO(dresdner) do not hard code

  def ifft(x):
    # Pass the real-space shape explicitly. Without `s`, irfft2 assumes an
    # even last dimension and returns one sample too few for odd grid sizes.
    return jnp.fft.irfft2(x, s=grid.shape)

  def forcing_fn_ret(vorticity):
    vorticity, = array_utils.split_axis(vorticity, axis=-1)  # channel dim = 1
    v = tuple(
        grids.GridVariable(grids.GridArray(ifft(u), offset, grid), bc)
        for u in velocity_solve(fft(vorticity)))
    fhatu, fhatv = tuple(fft(u) for u in forcing_fn(v))
    # Curl in Fourier space: d/dx -> 2*pi*i*kx, d/dy -> 2*pi*i*ky.
    fhat_vorticity = 2j * jnp.pi * (fhatv * kx - fhatu * ky)
    return ifft(fhat_vorticity)

  return forcing_fn_ret


@gin.register
def kolmogorov_forcing(grid: grids.Grid,
                       scale: float = 0,
                       wavenumber: int = 2,
                       linear_coefficient: float = 0,
                       swap_xy: bool = False,
                       offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
                       ) -> ForcingFn:
  """Kolmogorov forcing, optionally summed with a linear forcing.

  Args:
    grid: grid the forcing is defined on.
    scale: amplitude forwarded to `forcings.kolmogorov_forcing`.
    wavenumber: wavenumber forwarded to `forcings.kolmogorov_forcing`.
    linear_coefficient: when nonzero, a `forcings.linear_forcing` with this
      coefficient is added via `forcings.sum_forcings`.
    swap_xy: forwarded to `forcings.kolmogorov_forcing`.
    offsets: forwarded to `forcings.kolmogorov_forcing`.

  Returns:
    The Kolmogorov forcing function, or its sum with the linear forcing.
  """
  kolmogorov_fn = forcings.kolmogorov_forcing(
      grid, scale, wavenumber, swap_xy, offsets=offsets)
  if linear_coefficient == 0:
    return kolmogorov_fn
  return forcings.sum_forcings(
      kolmogorov_fn, forcings.linear_forcing(grid, linear_coefficient))


@gin.register
def taylor_green_forcing(grid: grids.Grid,
                         scale: float = 0,
                         wavenumber: int = 2,
                         linear_coefficient: float = 0) -> ForcingFn:
  """Taylor-Green forcing, optionally summed with a linear forcing.

  Args:
    grid: grid the forcing is defined on.
    scale: amplitude forwarded to `forcings.taylor_green_forcing`.
    wavenumber: wavenumber forwarded to `forcings.taylor_green_forcing`.
    linear_coefficient: when nonzero, a `forcings.linear_forcing` with this
      coefficient is added via `forcings.sum_forcings`.

  Returns:
    The Taylor-Green forcing function, or its sum with the linear forcing.
  """
  taylor_green_fn = forcings.taylor_green_forcing(grid, scale, wavenumber)
  if linear_coefficient == 0:
    return taylor_green_fn
  return forcings.sum_forcings(
      taylor_green_fn, forcings.linear_forcing(grid, linear_coefficient))


@gin.register
def no_forcing(grid: grids.Grid) -> ForcingFn:
  """Gin-configurable wrapper around `forcings.no_forcing`.

  Args:
    grid: grid the forcing is defined on.

  Returns:
    The forcing function built by `forcings.no_forcing(grid)`.
  """
  force_fn = forcings.no_forcing(grid)
  return force_fn