equations_test.py 13.4 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for spectral equations."""

from typing import Tuple

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import numpy as jnp
import jax_cfd.base as cfd
from jax_cfd.base import finite_differences
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import test_util
from jax_cfd.spectral import equations as spectral_equations
from jax_cfd.spectral import time_stepping

# Named parameter sets (one per spectral time stepper) for
# `parameterized.named_parameters`; each test case is named '_<stepper name>'.
ALL_TIME_STEPPERS = [
    dict(testcase_name='_' + stepper.__name__, time_stepper=stepper)
    for stepper in (
        time_stepping.backward_forward_euler,
        time_stepping.crank_nicolson_rk2,
        time_stepping.crank_nicolson_rk3,
        time_stepping.crank_nicolson_rk4,
    )
]


def roll(arr, offset: Tuple[int, ...]):
  """Rolls an n-dim arr by offset, one shift per axis.

  Args:
    arr: array to roll; must have exactly `len(offset)` dimensions.
    offset: integer shift for each axis of `arr`, in axis order.

  Returns:
    `arr` rolled by `offset[i]` positions along axis `i` for every `i`.

  Raises:
    ValueError: if `len(offset)` differs from the number of axes of `arr`.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if len(offset) != len(arr.shape):
    raise ValueError(
        f'offset has {len(offset)} entries but arr has {len(arr.shape)} axes')
  for axis, shift in enumerate(offset):
    arr = jnp.roll(arr, shift, axis=axis)
  return arr


def get_grid(resolution, domain=(0, 2*jnp.pi)):
  """Builds a one-dimensional `grids.Grid` with `resolution` points over `domain`."""
  shape = (resolution,)
  domains = (domain,)
  return grids.Grid(shape, domain=domains)


def get_zeros_initial_condition(grid, dtype=jnp.complex64):
  """Returns an all-zero spectral state for a one-dimensional `grid`.

  The length is n // 2 + 1 for an n-point grid, i.e. the number of
  coefficients `jnp.fft.rfft` produces for a real signal of length n.
  """
  (num_points,) = grid.shape
  num_modes = num_points // 2 + 1
  return jnp.zeros(num_modes, dtype=dtype)


def get_sine_initial_condition(grid):
  """Returns the real FFT of sin(x) sampled on the grid's cell-offset-0 axis."""
  (x,) = grid.axes(offset=(0,))
  signal = jnp.sin(x)
  return jnp.fft.rfft(signal)


class EquationsTest1D(test_util.TestCase):
  """Tests for the one-dimensional spectral equations."""

  def test_ks_equation(self):
    """Test that the KS equation (1) does not explode and (2) conserves momentum."""
    size = 128
    outer_steps = 2100

    length = 10. * jnp.pi
    grid = cfd.grids.Grid((size,), domain=((0, length),))
    dx, = grid.step
    dt = dx / length

    # The initial condition does not depend on `smooth`, so build it once.
    xs, = grid.axes()
    v0 = jnp.fft.rfft(jnp.cos((1 / length) * xs))

    # TODO(dresdner) make a parameterized test
    for smooth in [True, False]:
      # subTest reports which `smooth` variant failed.
      with self.subTest(smooth=smooth):
        step_fn = time_stepping.backward_forward_euler(
            spectral_equations.KuramotoSivashinsky(grid, smooth=smooth), dt)
        rollout_fn = jax.jit(cfd.funcutils.trajectory(step_fn, outer_steps))
        _, trajectory = jax.device_get(rollout_fn(v0))

        # irfft already returns a real array; pass `n` so odd sizes also
        # round-trip to the original length.
        real_space_trajectory = jnp.fft.irfft(trajectory, n=size)
        # Ensure no explosion in either direction. NaNs also fail this check
        # because every comparison with NaN is False.
        self.assertTrue(jnp.all(jnp.abs(real_space_trajectory) < 1e5))

        # conservation of momentum: momentum does not change over time
        initial_momentum = real_space_trajectory[0].sum()
        self.assertAllClose(
            initial_momentum, jnp.sum(real_space_trajectory, axis=1), atol=1e-3)

  @parameterized.named_parameters(
      dict(
          testcase_name='one_step_zeros',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_zeros_initial_condition,
          num_steps=1,
      ),
      dict(
          testcase_name='one_step_sine',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_sine_initial_condition,
          num_steps=1),
      dict(
          testcase_name='many_step_zeros',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_zeros_initial_condition,
          num_steps=1000),
      dict(
          testcase_name='many_step_sine',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_sine_initial_condition,
          num_steps=1000),
  )
  def test_burgers_equation(self, viscosity, grid, time_step,
                            initial_condition_fn, num_steps):
    """Check that the trajectories don't give NaNs."""
    eq = spectral_equations.BurgersEquation(viscosity=viscosity, grid=grid)
    step_fn = time_stepping.crank_nicolson_rk2(eq, time_step)
    step_fn = cfd.funcutils.repeated(step_fn, num_steps)
    uhat0 = initial_condition_fn(grid)
    t0 = 0.0
    # The stepper state here is a (spectral field, time) pair.
    uhat1, _ = step_fn((uhat0, t0))
    self.assertFalse(jnp.isnan(uhat1).any())

  @parameterized.named_parameters(
      dict(
          testcase_name='one_step_zeros',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_zeros_initial_condition,
          num_steps=1,
      ),
      dict(
          testcase_name='many_step_zeros',
          viscosity=0.01,
          grid=get_grid(128),
          time_step=0.01,
          initial_condition_fn=get_zeros_initial_condition,
          num_steps=1000),
  )
  def test_forced_burgers_equation(self, viscosity, grid, time_step,
                                   initial_condition_fn, num_steps):
    """Check that the trajectories don't give NaNs."""
    eq = spectral_equations.ForcedBurgersEquation(
        viscosity=viscosity, grid=grid)
    step_fn = time_stepping.crank_nicolson_rk2(eq, time_step)
    step_fn = cfd.funcutils.repeated(step_fn, num_steps)
    uhat0 = initial_condition_fn(grid)
    t0 = 0.0
    # The stepper state here is a (spectral field, time) pair.
    uhat1, _ = step_fn((uhat0, t0))
    self.assertFalse(jnp.isnan(uhat1).any())

  def test_nls_equation(self):
    """Check that trajectory matches Peregrine soliton analytic solution.

    Soln from https://en.wikipedia.org/wiki/Peregrine_soliton,
    however as we implement `psi_t = -i psi_xx/8 - i|psi|^2 psi/2`
    rather than `psi_t = +i psi_xx/2 +i|psi|^2 psi` from the wiki,
    the solution needs to be rescaled and conjugated.
    """

    def solve_nls(u0, t_final=1., max_samples=1024, dt=1e-2, extent=500):
      """Integrates NLS from `u0`; returns (trajectory, xs, sample times)."""
      N = len(u0)  # pylint: disable=invalid-name
      grid = grids.Grid((N,), domain=((-extent / 2, extent / 2),))
      xs, = grid.axes(offset=(0,))
      eq = spectral_equations.NonlinearSchrodinger(grid=grid)
      stepfn = time_stepping.crank_nicolson_rk4(eq, dt)
      uhat0 = jnp.fft.fft(u0)
      numsteps = int(t_final / dt)
      # Record one frame every `ds_period` solver steps.
      ds_period = max(numsteps // max_samples, 1)
      multistepfn = jax.jit(cfd.funcutils.repeated(stepfn, ds_period))
      _, uhat_traj = cfd.funcutils.trajectory(multistepfn, max_samples)(uhat0)
      u_traj = jax.vmap(jnp.fft.ifft)(uhat_traj)
      # `trajectory` always records exactly `max_samples` frames, so size the
      # time axis to match even when `numsteps < max_samples`.
      timesteps = (1 + jnp.arange(max_samples)) * dt * ds_period
      return u_traj, xs, timesteps

    L = 40 * jnp.pi  # pylint: disable=invalid-name
    grid = grids.Grid((2**10,), domain=((-L / 2, L / 2),))
    dt = 3e-4
    tau = 8
    T = tau * 2  # pylint: disable=invalid-name
    xs, = grid.axes(offset=(0,))
    # Rescale space/time from this equation's coefficients to the wiki form.
    zs = xs * jnp.sqrt(2)
    u0 = (4 * zs**2 - 3) / (1 + 4 * zs**2)
    soln, x_ds, t_ds = solve_nls(u0, T, dt=dt, extent=L)
    z_ds = x_ds * jnp.sqrt(2)
    tau_ds = t_ds / 2
    gt_soln = 1 - 4 * (1 +
                       2j * tau_ds[:, None]) / (1 + 4 *
                                                (z_ds**2 + tau_ds[:, None]**2))
    gt_soln = jnp.conj(gt_soln * jnp.exp(1j * tau_ds[:, None]))
    self.assertLess(jnp.abs(soln - gt_soln).mean(), 1e-3)


class EquationsTest2D(test_util.TestCase):
  """Tests for the two-dimensional spectral Navier-Stokes equations."""

  @parameterized.named_parameters(ALL_TIME_STEPPERS)
  def test_forced_turbulence(self, time_stepper):
    """Check that forced turbulence runs for 100 steps without blowing up."""
    grid = grids.Grid((128, 128), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
    v0 = cfd.initial_conditions.filtered_velocity_field(
        jax.random.PRNGKey(42), grid, 7, 4)
    vorticity0 = cfd.finite_differences.curl_2d(v0).data
    vorticity_hat0 = jnp.fft.rfftn(vorticity0)

    viscosity = 1e-3
    dt = 1e-5

    step_fn = time_stepper(
        spectral_equations.NavierStokes2D(
            viscosity,
            grid,
            forcing_fn=forcings.kolmogorov_forcing,
            drag=0.1), dt)

    trajectory_fn = cfd.funcutils.trajectory(step_fn, 100)
    _, trajectory = trajectory_fn(vorticity_hat0)
    # `isfinite` rejects both NaN and +/-inf, so overflow is caught as well.
    self.assertTrue(jnp.all(jnp.isfinite(trajectory)))

  def test_viscosity(self):
    """Test that higher viscosity results in faster decay."""
    grid = grids.Grid((128, 128), domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))
    v0 = cfd.initial_conditions.filtered_velocity_field(
        jax.random.PRNGKey(42), grid, 7, 4)
    vorticity0 = cfd.finite_differences.curl_2d(v0).data
    vorticity_hat0 = jnp.fft.rfftn(vorticity0)

    norms = []
    for viscosity in [1e-3, 1e-1, 1]:
      dt = cfd.equations.stable_time_step(
          7, .5, viscosity, grid, implicit_diffusion=True)
      step_fn = time_stepping.crank_nicolson_rk4(
          spectral_equations.NavierStokes2D(
              viscosity,
              grid,
              forcing_fn=forcings.kolmogorov_forcing,
              drag=0.1), dt)

      trajectory_fn = cfd.funcutils.trajectory(step_fn, 100)
      _, trajectory = trajectory_fn(vorticity_hat0)

      # One scalar per viscosity: the norm over the whole 100-step trajectory.
      norms.append(jnp.linalg.norm(trajectory))

    # higher viscosity means that you get to zero faster.
    self.assertLess(norms[2], norms[1])
    self.assertLess(norms[1], norms[0])

  @parameterized.named_parameters(
      dict(
          testcase_name='_TaylorGreen_SemiImplicitNavierStokes',
          problem=cfd.validation_problems.TaylorGreen(
              shape=(1024, 1024), density=1., viscosity=1e-3),
          equation=spectral_equations.NavierStokes2D,
          time_stepper=time_stepping.crank_nicolson_rk4,
          max_courant_number=.5,
          time=.11,
          atol=1e-3),)
  def test_accuracy(self, problem, equation, time_stepper, max_courant_number,
                    time, atol):
    """Check numerical accuracy of our solvers to known analytic solutions."""
    # This closely emulates the accuracy test in jax_cfd's
    # `jax_cfd/base/validation_test.py`.
    v0 = problem.velocity(0.)
    vorticity = finite_differences.curl_2d(v0).data

    stable_dt = cfd.equations.stable_time_step(
        7,
        max_courant_number,
        problem.viscosity,
        problem.grid,
        implicit_diffusion=True)
    steps = int(jnp.ceil(time / stable_dt))
    # Shrink the step so the simulation ends exactly at `time`, where the
    # analytic solution is evaluated below; this dt is <= the stable one.
    dt = time / steps
    step_fn = time_stepper(
        equation(
            viscosity=problem.viscosity,
            grid=problem.grid,
            forcing_fn=None,
            drag=0), dt)

    vorticity_hat_final = cfd.funcutils.repeated(step_fn, steps)(
        jnp.fft.rfftn(vorticity))
    vorticity_computed = jnp.fft.irfftn(
        vorticity_hat_final, s=problem.grid.shape)

    v = problem.velocity(time)
    vorticity_analytic = finite_differences.curl_2d(v).data

    self.assertAllClose(vorticity_computed, vorticity_analytic, atol=atol)

  @parameterized.named_parameters(
      dict(
          testcase_name='_decaying_turbulence',
          viscosity=1e-2,
          cfl_safety_factor=.1,
          max_velocity=2.0,
          peak_wavenumber=4,
          seed=0,
          density=1.0,
          n_steps=500,
          grid_size=512,
          is_forced=False,
          atol=0.09,
          ),
      dict(
          testcase_name='_forced_turbulence',
          viscosity=1e-2,
          cfl_safety_factor=.1,
          max_velocity=2.0,
          peak_wavenumber=4,
          seed=0,
          density=1.0,
          n_steps=150,
          grid_size=512,
          is_forced=True,
          atol=0.07,
          ),
      )
  def test_compare_to_finite_difference_method(self, viscosity,
                                               cfl_safety_factor, max_velocity,
                                               peak_wavenumber, seed, density,
                                               n_steps, grid_size,
                                               is_forced,
                                               atol):
    """Compare spectral to finite volume."""
    # NOTE(review): `peak_wavenumber` is accepted but never passed to
    # `filtered_velocity_field` below, so that function's own default is used.
    # The `atol` values were presumably tuned against that default; confirm
    # before wiring the parameter through.
    del peak_wavenumber  # Currently unused; see note above.

    grid = cfd.grids.Grid((grid_size, grid_size),
                          domain=((0, 2 * jnp.pi), (0, 2 * jnp.pi)))

    # Construct a random initial velocity.
    v0 = cfd.initial_conditions.filtered_velocity_field(
        jax.random.PRNGKey(seed), grid, max_velocity)

    # Choose a time step.
    dt = cfd.equations.stable_time_step(max_velocity, cfl_safety_factor,
                                        viscosity, grid)

    if is_forced:
      fvm_forcing = forcings.simple_turbulence_forcing(
          grid,
          constant_magnitude=1,
          constant_wavenumber=4,
          linear_coefficient=-0.1,
          forcing_type='kolmogorov')

      eq = spectral_equations.ForcedNavierStokes2D(
          viscosity, grid, smooth=True)
    else:
      fvm_forcing = None
      eq = spectral_equations.NavierStokes2D(
          viscosity, grid, smooth=True, drag=0, forcing_fn=None)

    # use `repeated` since we only compare the final state
    fvm_rollout_fn = jax.jit(
        cfd.funcutils.repeated(
            cfd.equations.semi_implicit_navier_stokes(
                density=density,
                viscosity=viscosity,
                dt=dt,
                grid=grid,
                forcing=fvm_forcing),
            steps=n_steps))

    v = fvm_rollout_fn(v0)
    final_state_fvm = cfd.finite_differences.curl_2d(v).data

    spectral_rollout_fn = jax.jit(
        cfd.funcutils.repeated(time_stepping.crank_nicolson_rk4(eq, dt),
                               steps=n_steps))

    # The spectral run starts from the finite-volume vorticity shifted by one
    # cell along each axis, and its result is shifted back before comparing.
    # Presumably this aligns the staggered curl output with the spectral
    # grid's sample points — TODO confirm.
    final_state_spectral = jnp.fft.irfftn(
        spectral_rollout_fn(
            jnp.fft.rfftn(
                roll(cfd.finite_differences.curl_2d(v0).data, (1, 1)))),
        s=grid.shape)

    self.assertAllClose(
        final_state_fvm, roll(final_state_spectral, (-1, -1)), atol=atol)


if __name__ == '__main__':
  # Run every test case in this module through absl's test runner.
  absltest.main()