initial_conditions.py 6.07 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Prepare initial conditions for simulations."""
import functools
from typing import Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import filter_utils
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.base import pressure
import numpy as np

# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic

# Array type used in this module's signatures: a NumPy array or a JAX array.
Array = Union[np.ndarray, jax.Array]
# Short module-local aliases for the types defined in `jax_cfd.base.grids`,
# so that annotations below do not need the `grids.` prefix.
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions


def wrap_variables(
    var: Sequence[Array],
    grid: grids.Grid,
    bcs: Sequence[BoundaryConditions],
    offsets: Optional[Sequence[Tuple[float, ...]]] = None,
    batch_dim: bool = False,
) -> GridVariableVector:
  """Attaches offsets, a grid, and boundary conditions to a sequence of arrays.

  Args:
    var: the arrays to wrap, one per component.
    grid: the grid passed to every `grids.GridArray` that is constructed.
    bcs: boundary conditions, one per component; each entry's `impose_bc` is
      applied to the corresponding wrapped array.
    offsets: offsets, one per component. If unspecified, `grid.cell_faces` is
      used.
    batch_dim: if True, the wrapping is applied through `jax.vmap` over `var`
      instead of being applied to `var` directly.

  Returns:
    A tuple holding the result of `bc.impose_bc` for each
    (array, offset, boundary condition) triple. Components are paired with
    `zip`, so pairing stops at the shortest of the three sequences.
  """
  component_offsets = grid.cell_faces if offsets is None else offsets

  def _wrap_components(components):
    wrapped = []
    for data, offset, bc in zip(components, component_offsets, bcs):
      wrapped.append(bc.impose_bc(grids.GridArray(data, offset, grid)))
    return tuple(wrapped)

  # With a batch dimension the same wrapping is vectorized over `var`.
  wrap_fn = jax.vmap(_wrap_components) if batch_dim else _wrap_components
  return wrap_fn(var)


def _log_normal_pdf(x, mode, variance=.25):
  """Unscaled PDF for a log normal given `mode` and log variance 1."""
  mean = jnp.log(mode) + variance
  logx = jnp.log(x)
  return jnp.exp(-(mean - logx)**2 / 2 / variance - logx)


def _max_speed(v):
  return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()


def filtered_velocity_field(
    rng_key: grids.Array,
    grid: grids.Grid,
    maximum_velocity: float = 1,
    peak_wavenumber: float = 3,
    iterations: int = 3,
) -> GridVariableVector:
  """Builds a random, spectrally filtered, divergence-free velocity field.

  One array of normal noise is drawn per grid dimension, filtered with a
  spectral density peaked at `peak_wavenumber`, wrapped with periodic boundary
  conditions, and then repeatedly pressure-projected and rescaled.

  Args:
    rng_key: key for seeding the random initial velocity field.
    grid: the grid on which the velocity field is defined.
    maximum_velocity: the maximum speed in the velocity field.
    peak_wavenumber: the velocity field will be filtered so that the largest
      magnitudes are associated with this wavenumber.
    iterations: the number of repeated pressure projection and normalization
      iterations to apply.

  Returns:
    A divergence free velocity field with the given maximum velocity. Associates
    periodic boundary conditions with the velocity field components.
  """
  def spectral_density(k):
    # Log normal distribution peaked at `peak_wavenumber`. The division by
    # `k ** (ndim - 1)` accounts for the volume of the `ndim - 1`-sphere of
    # values with wavenumber `k`.
    return _log_normal_pdf(k, peak_wavenumber) / k ** (grid.ndim - 1)

  # TODO(b/156601712): Switch back to the vmapped implementation of filtering:
  # noise = jax.random.normal(rng_key, (grid.ndim,) + grid.shape)
  # filtered = wrap_velocities(jax.vmap(spectral.filter, (None, 0, None))(
  #     spectral_density, noise, grid), grid)
  component_keys = jax.random.split(rng_key, grid.ndim)
  # One independently seeded, filtered noise array per velocity component.
  velocity_components = [
      filter_utils.filter(
          spectral_density, jax.random.normal(key, grid.shape), grid)
      for key in component_keys
  ]
  boundary_conditions = [
      boundaries.periodic_boundary_conditions(grid.ndim)
      for _ in velocity_components
  ]
  velocity = wrap_variables(velocity_components, grid, boundary_conditions)

  def project_and_normalize(v: GridVariableVector):
    projected = pressure.projection(v)
    peak_speed = _max_speed(projected)
    # Rescale every component so the largest speed equals `maximum_velocity`.
    return tuple(
        grids.GridVariable(maximum_velocity * u.array / peak_speed, u.bc)
        for u in projected)

  # Due to numerical precision issues, projection and normalization are applied
  # repeatedly so that the field is divergence-free and achieves the specified
  # maximum velocity.
  repeated_step = funcutils.repeated(project_and_normalize, iterations)
  return repeated_step(velocity)


def initial_velocity_field(
    velocity_fns: Tuple[Callable[..., Array], ...],
    grid: grids.Grid,
    velocity_bc: Optional[Sequence[BoundaryConditions]] = None,
    pressure_solve: Callable = pressure.solve_fast_diag,
    iterations: Optional[int] = None,
) -> GridVariableVector:
  """Given velocity functions on arrays, returns the velocity field on the grid.

  Typical usage example:
    grid = cfd.grids.Grid((128, 128))
    x_velocity_fn = lambda x, y: jnp.sin(x) * jnp.cos(y)
    y_velocity_fn = lambda x, y: jnp.zeros_like(x)
    v0 = initial_velocity_field(
        (x_velocity_fn, y_velocity_fn), grid, iterations=5)

  Args:
    velocity_fns: functions for computing each velocity component. These should
      take the args (x, y, ...) and return an array of the same shape.
    grid: the grid on which the velocity field is defined.
    velocity_bc: the boundary conditions to associate with each velocity
      component. If unspecified, uses periodic boundary conditions.
    pressure_solve: method used to solve pressure projection.
    iterations: if specified, the number of iterations of applied projection
      onto an incompressible velocity field.

  Returns:
    Velocity components defined with expected offsets on the grid.

  Raises:
    ValueError: if `velocity_fns`, `grid.cell_faces` and `velocity_bc` do not
      all have the same length.
  """
  if velocity_bc is None:
    velocity_bc = (
        boundaries.periodic_boundary_conditions(grid.ndim),) * grid.ndim
  offsets = grid.cell_faces
  # `zip` silently stops at the shortest input, which would quietly drop
  # velocity components; reject mismatched lengths up front instead.
  if not len(velocity_fns) == len(offsets) == len(velocity_bc):
    raise ValueError(
        'Expected one velocity function and one boundary condition per cell '
        f'face offset, got {len(velocity_fns)} velocity functions, '
        f'{len(velocity_bc)} boundary conditions and {len(offsets)} offsets.')
  v = tuple(
      grids.GridVariable(grid.eval_on_mesh(v_fn, offset), bc)
      for v_fn, offset, bc in zip(velocity_fns, offsets, velocity_bc))
  if iterations is not None:
    projection = functools.partial(pressure.projection, solve=pressure_solve)
    v = funcutils.repeated(projection, iterations)(v)
  return v