"doc/vscode:/vscode.git/clone" did not exist on "0295965d784ac46b5a75b9076aed21121be69eed"
utils.py 6.64 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for building pseudospectral methods."""

from typing import Callable, Tuple

import jax.numpy as jnp
from jax_cfd.base import grids
from jax_cfd.spectral import types as spectral_types


def truncated_rfft(u: spectral_types.Array) -> spectral_types.Array:
  """Takes the rfft of `u` and keeps only its lowest ~2/3 of modes (2/3 rule).

  See also `padded_irfft`, which applies the complementary 3/2 rule.

  Args:
    u: one-dimensional real-space signal.

  Returns:
    The first `int(2 / 3 * k) + 1` rfft coefficients of `u`, where `k` is the
    total number of rfft coefficients, scaled by a factor of 2/3.
  """
  spectrum = jnp.fft.rfft(u)
  (num_modes,) = spectrum.shape
  keep = int(2 / 3 * num_modes) + 1
  low_modes = spectrum[:keep]
  # Rescale by the same 2/3 ratio used for the truncation.
  return 2 / 3 * low_modes


def padded_irfft(uhat: spectral_types.Array) -> spectral_types.Array:
  """Zero-pads an rfft spectrum to 3/2 its length and inverts it (3/2 rule).

  Args:
    uhat: one-dimensional rfft representation of a signal, with `n`
      coefficients.

  Returns:
    `1.5 * jnp.fft.irfft(padded)`, where `padded` is `uhat` extended with zeros
    at the high-frequency end to `int(3 / 2 * n)` coefficients. The result is a
    real-space signal upsampled by roughly a factor of 3/2 relative to the
    signal `uhat` represents.
  """
  (num_coeffs,) = uhat.shape
  # Identical to int(3 / 2 * num_coeffs) for non-negative integer lengths.
  target = 3 * num_coeffs // 2
  padded = jnp.pad(uhat, (0, target - num_coeffs))
  assert padded.shape == (target,), "incorrect padded shape"
  # The 1.5 factor offsets the 1/N normalization of irfft over the longer
  # output length.
  return 1.5 * jnp.fft.irfft(padded)


def truncated_fft_2x(u: spectral_types.Array) -> spectral_types.Array:
  """Keeps the central ~half of the Fourier modes of a complex signal (1/2 rule).

  Args:
    u: one-dimensional (complex) input signal of length `k`.

  Returns:
    The fft coefficients of `u` in the centered band
    `[h // 2, (1 - h) // 2)` of the fftshift-ed spectrum, where
    `h = (k + 1) // 2`. The band is returned in standard fft order
    (ifftshift-ed) and divided by 2.
  """
  centered = jnp.fft.fftshift(jnp.fft.fft(u))
  (num_modes,) = centered.shape
  half = (num_modes + 1) // 2
  # Same bounds as final_size // 2 and (-final_size + 1) // 2.
  start = half // 2
  stop = (1 - half) // 2
  band = centered[start:stop]
  return jnp.fft.ifftshift(band) / 2


def padded_ifft_2x(uhat: spectral_types.Array) -> spectral_types.Array:
  """Inverts an fft spectrum after zero-padding its high frequencies (2x rule).

  The centered (fftshift-ed) spectrum is padded with `n // 2` zeros on each
  side before the inverse transform. This effectively interpolates the signal
  onto roughly twice as many points in the spatial domain.

  Args:
    uhat: one-dimensional fft representation of a signal, of length `n`.

  Returns:
    `2 * jnp.fft.ifft(...)` of the padded spectrum: a signal of length
    `n + 2 * (n // 2)`.
  """
  (n,) = uhat.shape
  pad = n // 2
  expected_len = n + 2 * pad
  centered = jnp.fft.fftshift(uhat)
  widened = jnp.pad(centered, (pad, pad))
  assert widened.shape == (expected_len,), "incorrect padded shape"
  # Restore fft ordering before inverting; the factor 2 offsets ifft's 1/N
  # normalization over the doubled length.
  return 2 * jnp.fft.ifft(jnp.fft.ifftshift(widened))


def circular_filter_2d(grid: grids.Grid) -> spectral_types.Array:
  """Smooth radial spectral filter, a softer alternative to the 2/3 rule.

  Let `k_c = 0.65 * ky[-1, -1]`, where `ky` comes from `grid.rfft_mesh()`.
  Modes with radius `|k| <= k_c` pass with factor 1. Modes beyond that radius
  are damped by `exp(-23.6 * (|k| - k_c)**4)`. This follows Equation 1 of [1];
  the value of alpha (23.6) matches the one used by pyqg [2].

  Args:
    grid: the grid to filter over; only `grid.rfft_mesh()` is used.

  Returns:
    Filter mask with the same shape as the rfft wavenumber mesh.

  Reference:
    [1] Arbic, Brian K., and Glenn R. Flierl. "Coherent vortices and kinetic
    energy ribbons in asymptotic, quasi two-dimensional f-plane turbulence."
    Physics of Fluids 15, no. 8 (2003): 2177-2189.
    https://doi.org/10.1063/1.1582183

    [2] Ryan Abernathey, rochanotes, Malte Jansen, Francis J. Poulin, Navid C.
    Constantinou, Dhruv Balwada, Anirban Sinha, Mike Bueti, James Penn,
    Christopher L. Pitt Wolfe, & Bia Villas Boas. (2019). pyqg/pyqg: v0.3.0
    (v0.3.0). Zenodo. https://doi.org/10.5281/zenodo.3551326.
    See:
    https://github.com/pyqg/pyqg/blob/02e8e713660d6b2043410f2fef6a186a7cb225a6/pyqg/model.py#L136
  """
  cutoff_fraction = 0.65
  sharpness = 23.6

  kx, ky = grid.rfft_mesh()
  # Reference wavenumber: the last entry of the ky mesh.
  k_ref = ky[-1, -1]
  cutoff = cutoff_fraction * k_ref
  radius = jnp.sqrt(kx**2 + ky**2)

  decay = jnp.exp(-sharpness * (radius - cutoff)**4.)
  # Inside the cutoff circle the filter is exactly 1.
  return jnp.where(radius <= cutoff, jnp.ones_like(decay), decay)


def brick_wall_filter_2d(grid: grids.Grid):
  """Builds a 0/1 mask implementing the 2/3 truncation rule on an rfft2 grid.

  Args:
    grid: grid whose `shape` is `(n, m)`.

  Returns:
    Array of shape `(n, m // 2 + 1)`. Entries are 1 in the leading
    `int(2 / 3 * (m // 2 + 1))` columns of the retained top and bottom row
    bands, and 0 everywhere else.
  """
  n, m = grid.shape
  num_cols = m // 2 + 1
  row_cutoff = int(2 / 3 * n)
  col_cutoff = int(2 / 3 * num_cols)

  # Rows are filled from both ends, presumably because they follow fft ordering
  # (non-negative frequencies first, then negative ones).
  # NOTE(review): unary minus binds tighter than //, so the bottom band starts
  # at (-row_cutoff) // 2. When row_cutoff is odd, the bottom band is one row
  # wider than the top band. Confirm this asymmetry is intended.
  top_rows = row_cutoff // 2
  bottom_start = (-row_cutoff) // 2

  mask = jnp.zeros((n, num_cols))
  mask = mask.at[:top_rows, :col_cutoff].set(1)
  mask = mask.at[bottom_start:, :col_cutoff].set(1)
  return mask


def exponential_filter(signal, alpha=1e-6, order=2):
  """Apply a low-pass exponential smoothing filter to a 2D real signal.

  Each rfft2 coefficient is multiplied by
  `sigma = exp(log(alpha) * eta**(2 * order)) = alpha ** (eta**(2 * order))`.
  Here `eta` is the Euclidean norm of the normalized (cycles-per-sample)
  frequencies given by `jnp.fft.fftfreq` / `jnp.fft.rfftfreq`. The zero mode
  (eta = 0) passes unchanged.

  Args:
    signal: real 2D array of shape `(n, m)`. Non-square and odd-sized shapes
      are supported.
    alpha: attenuation parameter (must be > 0); smaller values damp the
      nonzero modes more strongly.
    order: exponent `p` in `eta**(2 * p)`.

  Returns:
    The filtered signal, with the same shape as `signal`.
  """
  # Based on:
  # 1. Gottlieb and Hesthaven (2001), "Spectral methods for hyperbolic problems"
  # https://doi.org/10.1016/S0377-0427(00)00510-0
  # 2. Also, see https://arxiv.org/pdf/math/0701337.pdf --- Eq. 5

  # TODO(dresdner) save a few ffts by factoring out the actual filter, sigma.
  # TODO(dresdner) handle 1D case
  strength = -jnp.log(alpha)
  n, m = signal.shape
  # Full spectrum along axis 0, half (rfft) spectrum along the last axis. Each
  # axis uses its own length, so non-square inputs line up with rfft2's output.
  kx = jnp.fft.fftfreq(n)
  ky = jnp.fft.rfftfreq(m)
  kx, ky = jnp.meshgrid(kx, ky, indexing="ij")
  eta = jnp.sqrt(kx**2 + ky**2)
  sigma = jnp.exp(-strength * eta**(2 * order))
  # Pass `s` explicitly: irfft2's default last-axis length is 2 * (m // 2),
  # which drops a column whenever m is odd.
  return jnp.fft.irfft2(sigma * jnp.fft.rfft2(signal), s=signal.shape)


def vorticity_to_velocity(
    grid: grids.Grid
) -> Callable[[spectral_types.Array], Tuple[spectral_types.Array,
                                            spectral_types.Array]]:
  """Builds a map from Fourier-space vorticity to Fourier-space velocity.

  The returned function first solves `laplacian(psi) = -vorticity` for the
  stream function `psi`. It then differentiates `psi` spectrally:
  `vx = d(psi)/dy` and `vy = -d(psi)/dx`. This is the standard approach; a
  quick sketch can be found in [1].

  Args:
    grid: the grid underlying the vorticity field; only `grid.rfft_mesh()` is
      used.

  Returns:
    A function that takes a vorticity (rfftn) and returns the velocity
    components `(vxhat, vyhat)`, also in Fourier space.

  Reference:
    [1] Z. Yin, H.J.H. Clercx, D.C. Montgomery, An easily implemented task-based
    parallel scheme for the Fourier pseudospectral solver applied to 2D
    Navier–Stokes turbulence, Computers & Fluids, Volume 33, Issue 4, 2004,
    Pages 509-520, ISSN 0045-7930,
    https://doi.org/10.1016/j.compfluid.2003.06.003.
  """
  kx, ky = grid.rfft_mesh()
  # Spectral derivative factor: d/dx <-> 2*pi*i*kx.
  ik_factor = 2 * jnp.pi * 1j
  # Fourier symbol of the Laplacian, (2*pi*i)**2 * |k|**2.
  laplacian_symbol = ik_factor ** 2 * (abs(kx)**2 + abs(ky)**2)
  # Entry (0, 0) is overwritten with 1 so the division below stays finite.
  # This is presumably the zero-wavenumber mode.
  laplacian_symbol = laplacian_symbol.at[0, 0].set(1)  # pytype: disable=attribute-error  # jnp-type

  def velocity_from_vorticity(vorticity_hat):
    stream_hat = -1 / laplacian_symbol * vorticity_hat
    return ik_factor * ky * stream_hat, -ik_factor * kx * stream_hat

  return velocity_from_vorticity


def filter_step(step_fn: spectral_types.StepFn, filter_: spectral_types.Array):
  """Wraps `step_fn` so that each of its outputs is multiplied by `filter_`.

  Args:
    step_fn: callable taking a single `state` argument.
    filter_: array multiplied elementwise onto every result of `step_fn`.

  Returns:
    A one-argument function computing `filter_ * step_fn(state)`.
  """
  def filtered_step_fn(state):
    unfiltered = step_fn(state)
    return filter_ * unfiltered

  return filtered_step_fn


def spectral_curl_2d(mesh, velocity_hat):
  """Computes the 2D curl `dv/dx - du/dy` in the Fourier basis.

  Args:
    mesh: pair `(kx, ky)` of wavenumber arrays.
    velocity_hat: pair `(uhat, vhat)` of Fourier-space velocity components.

  Returns:
    `2j * pi * (kx * vhat - ky * uhat)`.
  """
  wavenumber_x, wavenumber_y = mesh
  u_component, v_component = velocity_hat
  # Spectral derivative factor: d/dx <-> 2*pi*i*kx.
  derivative_factor = 2j * jnp.pi
  return derivative_factor * (
      wavenumber_x * v_component - wavenumber_y * u_component)