advections.py 2.73 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Models for advection and convection components."""
import functools

from typing import Callable, Optional
import gin
from jax_cfd.base import advection
from jax_cfd.base import grids
from jax_cfd.ml import interpolations
from jax_cfd.ml import physics_specifications


# Short local aliases for the grid container types used in this module.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
# Type of the interpolation factories accepted by the advection builders below;
# each is called as `module(grid, dt, physics_specs, **kwargs)`.
InterpolationModule = interpolations.InterpolationModule
# Callable returned by advection builders: called as `advect(c, v, dt)` and
# returning a GridArray.
# NOTE(review): the concrete `advect` closures in this file declare
# `dt: Optional[float] = None`, so `float` here is narrower than what they
# actually accept.
AdvectFn = Callable[[GridVariable, GridVariableVector, float], GridArray]
# Builder that produces an AdvectFn (e.g. `modular_advection`).
AdvectionModule = Callable[..., AdvectFn]
# Callable that maps a velocity vector to a vector of GridArrays.
ConvectFn = Callable[[GridVariableVector], GridArrayVector]
# Builder that produces a ConvectFn (e.g. `self_advection`).
ConvectionModule = Callable[..., ConvectFn]


@gin.register
def modular_advection(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    c_interpolation_module: InterpolationModule = interpolations.upwind,
    u_interpolation_module: InterpolationModule = interpolations.linear,
    **kwargs
) -> AdvectFn:
  """Builds an advection function from two independent interpolation modules.

  The returned callable delegates to `advection.advect_general`, using one
  interpolation function for the advected quantity and a separate one for
  the advecting velocity components.

  Args:
    grid: grid passed to both interpolation module constructors.
    dt: time step passed to both interpolation module constructors.
    physics_specs: physics specification passed to both constructors.
    c_interpolation_module: builds the interpolation function applied to the
      advected quantity `c`.
    u_interpolation_module: builds the interpolation function applied to the
      velocity components `v`.
    **kwargs: extra keyword arguments passed to both constructors.

  Returns:
    A function `advect(c, v, dt=None)` that returns a `GridArray`.
  """
  # Both constructors receive identical positional arguments. The scalar
  # interpolator is built first; NOTE(review): keep this order in case module
  # construction is order-sensitive (e.g. automatic module naming).
  shared_args = (grid, dt, physics_specs)
  interpolate_scalar = c_interpolation_module(*shared_args, **kwargs)
  interpolate_velocity = u_interpolation_module(*shared_args, **kwargs)

  def advect(
      c: GridVariable,
      v: GridVariableVector,
      dt: Optional[float] = None
  ) -> GridArray:
    # This `dt` shadows the builder's `dt` and is forwarded unchanged, so a
    # call that omits it passes `None` through to `advect_general`.
    return advection.advect_general(
        c, v, interpolate_velocity, interpolate_scalar, dt)

  return advect


@gin.register
def modular_self_advection(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    interpolation_module: InterpolationModule,
    **kwargs
) -> AdvectFn:
  """Builds an advection function backed by one shared interpolation module.

  A single interpolation function is constructed once and then specialized
  with `tag='c'` (for the advected quantity) and `tag='u'` (for the velocity
  components) before both views are handed to `advection.advect_general`.

  Args:
    grid: grid passed to the interpolation module constructor.
    dt: time step passed to the interpolation module constructor.
    physics_specs: physics specification passed to the constructor.
    interpolation_module: builds the shared interpolation function; the
      result must accept a `tag` keyword argument.
    **kwargs: extra keyword arguments passed to the constructor.

  Returns:
    A function `advect(c, v, dt=None)` that returns a `GridArray`.
  """
  # TODO(jamieas): Replace this entire function once
  # `single_tower_navier_stokes` is in place.
  shared_interpolate = interpolation_module(grid, dt, physics_specs, **kwargs)
  # One tagged view of the shared interpolator per role ('c' first, then 'u').
  tagged = {
      role: functools.partial(shared_interpolate, tag=role)
      for role in ('c', 'u')
  }

  def advect(
      c: GridVariable,
      v: GridVariableVector,
      dt: Optional[float] = None
  ) -> GridArray:
    # `dt` is forwarded unchanged; omitting it passes `None` through.
    return advection.advect_general(c, v, tagged['u'], tagged['c'], dt)

  return advect


@gin.register
def self_advection(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.NavierStokesPhysicsSpecs,
    advection_module: AdvectionModule = modular_advection,
    **kwargs
) -> ConvectFn:
  """Builds a convection function that self-advects every velocity component.

  Args:
    grid: grid passed to `advection_module`.
    dt: time step passed to `advection_module` and to every advect call.
    physics_specs: physics specification passed to `advection_module`.
    advection_module: builder for the AdvectFn applied to each component.
    **kwargs: extra keyword arguments passed to `advection_module`.

  Returns:
    A function `convect(v)` returning a tuple with one entry per component of
    `v`, in the same order, where each component is advected by the full
    velocity `v` using the `dt` captured here.
  """
  advect_component = advection_module(grid, dt, physics_specs, **kwargs)

  def convect(v: GridVariableVector) -> GridArrayVector:
    advected = [advect_component(component, v, dt) for component in v]
    return tuple(advected)

  return convect