initial_conditions.py 3.69 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Prepare initial conditions for simulations."""
from typing import Union

import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import filter_utils
from jax_cfd.base import funcutils
from jax_cfd.base import grids
from jax_cfd.collocated import pressure
import numpy as np

# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic

Array = Union[np.ndarray, jax.Array]
GridArray = grids.GridArray
GridArrayVector = grids.GridArrayVector
GridVariable = grids.GridVariable
GridVariableVector = grids.GridVariableVector
BoundaryConditions = grids.BoundaryConditions


def _log_normal_pdf(x, mode, variance=.25):
  """Unscaled PDF for a log normal given `mode` and log variance 1."""
  mean = jnp.log(mode) + variance
  logx = jnp.log(x)
  return jnp.exp(-(mean - logx)**2 / 2 / variance - logx)


def _max_speed(v):
  return jnp.linalg.norm(jnp.array([u.data for u in v]), axis=0).max()


def filtered_velocity_field(
    rng_key: grids.Array,
    grid: grids.Grid,
    maximum_velocity: float = 1,
    peak_wavenumber: float = 3,
    iterations: int = 3,
) -> GridVariableVector:
  """Create divergence-free velocity fields with appropriate spectral filtering.

  Modified version for collocated variables.

  Args:
    rng_key: key for seeding the random initial velocity field.
    grid: the grid on which the velocity field is defined.
    maximum_velocity: the maximum speed in the velocity field.
    peak_wavenumber: the velocity field will be filtered so that the largest
      magnitudes are associated with this wavenumber.
    iterations: the number of repeated pressure projection and normalization
      iterations to apply.
  Returns:
    A divergence free velocity field with the given maximum velocity. Associates
    periodic boundary conditions with the velocity field components.
  """
  # Log normal distribution peaked at `peak_wavenumber`. Note that we have to
  # divide by `k ** (ndim - 1)` to account for the volume of the
  # `ndim - 1`-sphere of values with wavenumber `k`.
  def spectral_density(k):
    return _log_normal_pdf(k, peak_wavenumber) / k ** (grid.ndim - 1)

  keys = jax.random.split(rng_key, grid.ndim)
  velocity_components = []
  boundary_conditions = []
  for k in keys:
    noise = jax.random.normal(k, grid.shape)
    velocity_components.append(
        filter_utils.filter(spectral_density, noise, grid))
    boundary_conditions.append(
        boundaries.periodic_boundary_conditions(grid.ndim))
  # Place values on cell-centered grid
  velocity = tuple(
      grids.GridVariable(grids.GridArray(u, grid.cell_center, grid), bc)
      for u, bc in zip(velocity_components, boundary_conditions))

  def project_and_normalize(v: GridVariableVector):
    v = pressure.projection(v)
    vmax = _max_speed(v)
    v = tuple(
        grids.GridVariable(maximum_velocity * u.array / vmax, u.bc) for u in v)
    return v
  # Due to numerical precision issues, we repeatedly normalize and project the
  # velocity field. This ensures that it is divergence-free and achieves the
  # specified maximum velocity.
  return funcutils.repeated(project_and_normalize, iterations)(velocity)