advection.py 2.57 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for functionality related to advection."""

from typing import Optional, Tuple

from jax_cfd.base import boundaries
from jax_cfd.base import finite_differences as fd
from jax_cfd.base import grids

# Short module-level aliases for the grid container types used in the
# annotations of the functions below.
GridArray, GridArrayVector = grids.GridArray, grids.GridArrayVector
GridVariable, GridVariableVector = grids.GridVariable, grids.GridVariableVector


def advect_linear(c: GridVariable,
                  v: GridVariableVector,
                  dt: Optional[float] = None) -> GridArray:
  """Returns the linear advection tendency `-div(c * v)` for a scalar `c`.

  The scalar `c` is collocated with the flux computation. For every axis the
  product `c * v[axis]` is wrapped with the advection-flux boundary condition
  derived from that velocity component and `c`. The negated centered
  divergence of those fluxes is returned.

  Args:
    c: scalar quantity being advected.
    v: velocity vector, one component per direction.
    dt: unused (discarded immediately); presumably accepted so the signature
      lines up with other advection functions. Confirm against callers.

  Returns:
    A `GridArray` holding `-div(c * v)`.
  """
  del dt
  # One flux boundary condition per velocity component, built up front.
  bc_per_axis = []
  for axis, velocity in enumerate(v):
    bc_per_axis.append(
        boundaries.get_advection_flux_bc_from_velocity_and_scalar(
            velocity, c, axis))
  flux_components = []
  for axis in range(c.grid.ndim):
    product = c.array * v[axis].array
    flux_components.append(bc_per_axis[axis].impose_bc(product))
  return -fd.centered_divergence(tuple(flux_components))


def _velocities_to_flux(
    v: GridVariableVector) -> Tuple[GridVariableVector, ...]:
  """Computes the cell-centered convective flux for a velocity field.

  This is the flux associated with the nonlinear term `vv` for velocity `v`.
  The boundary condition on each flux entry is derived from the velocity
  components that form it.

  Args:
    v: velocity vector.

  Returns:
    A tuple of `len(v)` tuples `flux`, where `flux[i][j]` is a `GridVariable`
    with the values `v[i]*v[j]`. The result is symmetric: for `i > j`,
    `flux[i][j]` is the very same object as `flux[j][i]`.
  """
  ndim = len(v)
  flux = [tuple() for _ in range(ndim)]
  for i in range(ndim):
    for j in range(ndim):
      if i <= j:
        # Upper triangle (including the diagonal): form v[i] * v[j] and attach
        # the advection-flux BC for velocity v[j] carrying v[i] along axis j.
        bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar(
            v[j], v[i], j)
        flux[i] += (bc.impose_bc(v[i].array * v[j].array),)
      else:
        # Lower triangle: v[i] * v[j] == v[j] * v[i], and row j (< i) is
        # already complete, so reuse that entry rather than recompute it.
        # NOTE(review): the reused entry carries the BC built for
        # (v[i], v[j], axis i), not (v[j], v[i], axis j); confirm intended.
        flux[i] += (flux[j][i],)
  return tuple(flux)


def convect_linear(v: GridVariableVector) -> GridArrayVector:
  """Returns the self-advection (convection) tendency of the velocity `v`.

  The per-component fluxes `v[i] * v[j]` come from `_velocities_to_flux`.
  Each component's tendency is the negated centered divergence of its row of
  fluxes.

  Args:
    v: velocity vector.

  Returns:
    A tuple with one entry per component of `v`: the time derivative of that
    component due to convection.
  """
  tendencies = []
  for component_flux in _velocities_to_flux(v):
    tendencies.append(-fd.centered_divergence(component_flux))
  return tuple(tendencies)