forcings.py 1.54 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Forcing functions for spectral equations."""

import jax
import jax.numpy as jnp
from jax_cfd.base import grids


def random_forcing_module(grid: grids.Grid,
                          seed: int = 0,
                          n: int = 20,
                          offset=(0,)):
  """Builds a time-dependent forcing made of `n` random sine waves.

  This is the forcing described in Bar-Sinai et al. [*]. Each wave has the
  form `A * sin(omega * t - k * x + phi)`, where, independently per wave:

    * `k` is drawn from the wavenumbers {3, 4, 5, 6},
    * `A` is drawn uniformly from [-0.5, 0.5),
    * `omega` is drawn uniformly from [-0.4, 0.4),
    * `phi` is drawn uniformly from [0, 2 * pi).

  All draws happen once, when this function is called; the returned function
  is deterministic given `seed` and `n`.

  Args:
    grid: grid supplying the x-axis. `grid.axes(offset=offset)` must return
      exactly one axis, since the result is unpacked into a single array.
    seed: integer seed for the JAX PRNG key used to sample the waves.
    n: number of random waves summed in the forcing.
    offset: offset forwarded to `grid.axes`. Defaults to (0,) for the Fourier
      basis.

  Returns:
    A function `forcing_fn(t)` that returns the sum of the `n` waves evaluated
    at every point of the x-axis, with the same shape as that axis.

  [*] Bar-Sinai, Yohai, Stephan Hoyer, Jason Hickey, and Michael P. Brenner.
  "Learning data-driven discretizations for partial differential equations."
  Proceedings of the National Academy of Sciences 116, no. 31 (2019):
  15344-15349.
  """
  # Peel one subkey off the root key per sampled quantity. The order of the
  # unpacking below (wavenumber, amplitude, frequency, phase) fixes which
  # subkey feeds which draw, so it must not be rearranged.
  rng = jax.random.PRNGKey(seed)
  subkeys = []
  for _ in range(4):
    rng, fresh = jax.random.split(rng)
    subkeys.append(fresh)
  wavenumber_key, amplitude_key, frequency_key, phase_key = subkeys

  wavenumbers = jax.random.choice(
      wavenumber_key, jnp.array([3, 4, 5, 6]), shape=(n,))
  amplitudes = jax.random.uniform(
      amplitude_key, shape=(n,), minval=-0.5, maxval=0.5)
  frequencies = jax.random.uniform(
      frequency_key, shape=(n,), minval=-0.4, maxval=0.4)
  phases = jax.random.uniform(
      phase_key, shape=(n,), minval=0, maxval=2 * jnp.pi)

  # Single-element unpacking: raises if the grid yields more than one axis.
  (points,) = grid.axes(offset=offset)

  def forcing_fn(t):
    """Evaluates the summed random waves on the x-axis at time `t`."""

    def force_at(x):
      # All `n` waves at one scalar position, collapsed to a single value.
      waves = amplitudes * jnp.sin(frequencies * t - x * wavenumbers + phases)
      return waves.sum()

    # Map the scalar evaluation over every grid point.
    return jnp.vectorize(force_at)(points)

  return forcing_fn