time_stepping_test.py 6.9 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for time_stepping."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import tree_util
from jax import config
import jax.numpy as jnp
from jax_cfd.base import funcutils
from jax_cfd.spectral import time_stepping
import numpy as np


def harmonic_oscillator(x0, t):
  """Returns the state at time `t` of a harmonic oscillator started at `x0`.

  The returned value is `r * [sin(t + theta), cos(t + theta)]`, where `r` is
  the Euclidean norm of `x0` and `theta` is chosen so that the value at
  `t = 0` equals `x0`. This satisfies dx[0]/dt = x[1], dx[1]/dt = -x[0].

  Args:
    x0: initial state; `x0[0]` and `x0[1]` are the two components.
    t: time at which to evaluate the solution.

  Returns:
    The two components of the solution at time `t`, stacked along axis 0.
  """
  # arctan2 keeps the quadrant of (x0[0], x0[1]); arctan(x0[0] / x0[1]) would
  # fold states with x0[1] < 0 onto the opposite quadrant and produce NaN for
  # the all-zero state (0 / 0).
  theta = jnp.arctan2(x0[0], x0[1])
  r = jnp.linalg.norm(x0, ord=2, axis=0)
  return r * jnp.stack([jnp.sin(t + theta), jnp.cos(t + theta)])


class CustomODE(time_stepping.ImplicitExplicitODE):
  """Implicit-explicit ODE assembled from three caller-supplied callables."""

  def __init__(self, explicit_terms, implicit_terms, implicit_solve):
    """Stores each callable as an instance attribute of the same name.

    Args:
      explicit_terms: callable stored as `self.explicit_terms`.
      implicit_terms: callable stored as `self.implicit_terms`.
      implicit_solve: callable stored as `self.implicit_solve`.
    """
    self.implicit_solve = implicit_solve
    self.implicit_terms = implicit_terms
    self.explicit_terms = explicit_terms


# Parameter sets for the parameterized tests in `TimeSteppingTest`. Each dict
# holds:
#   explicit_terms / implicit_terms / implicit_solve: callables passed to
#     `CustomODE`; `implicit_solve(x, eta)` must return `y` such that
#     `y - eta * implicit_terms(y) == x`.
#   dt: step size handed to each time stepper.
#   inner_steps: number of steps taken between two recorded states.
#   outer_steps: number of recorded states in the trajectory.
#   initial_state: state at t = 0.
#   closed_form: `closed_form(x0, t)` gives the reference solution at time t.
#   tolerances: absolute tolerances, paired by position with the entries of
#     `ALL_TIME_STEPPERS`.
ALL_TEST_PROBLEMS = [
    # x(t) = x0 = np.ones(10)
    dict(testcase_name='_zero_derivative',
         explicit_terms=lambda x: 0 * x,
         implicit_terms=lambda x: 0 * x,
         implicit_solve=lambda x, eta: x,
         dt=1e-2,
         inner_steps=10,
         outer_steps=5,
         initial_state=np.ones(10),
         closed_form=lambda x0, t: x0,
         tolerances=[1e-12] * 5),
    # x(t) = x0 + 5 * t, with x0 = np.ones(3)
    dict(testcase_name='_constant_derivative',
         explicit_terms=lambda x: 5 * jnp.ones_like(x),
         implicit_terms=lambda x: 0 * x,
         implicit_solve=lambda x, eta: x,
         dt=1e-2,
         inner_steps=10,
         outer_steps=5,
         initial_state=np.ones(3),
         closed_form=lambda x0, t: x0 + 5 * t,
         tolerances=[1e-12] * 5),
    # x(t) = np.arange(3) * np.exp(t)
    # Uses explicit terms only.
    dict(testcase_name='_linear_derivative_explicit',
         explicit_terms=lambda x: x,
         implicit_terms=lambda x: 0 * x,
         implicit_solve=lambda x, eta: x,
         dt=1e-2,
         inner_steps=20,
         outer_steps=5,
         initial_state=np.arange(3.0),
         closed_form=lambda x0, t: np.arange(3) * jnp.exp(t),
         tolerances=[5e-2, 1e-4, 1e-6, 1e-9, 1e-6]),
    # x(t) = np.arange(3) * np.exp(t)
    # Uses implicit terms only.
    dict(testcase_name='_linear_derivative_implicit',
         explicit_terms=lambda x: 0 * x,
         implicit_terms=lambda x: x,
         implicit_solve=lambda x, eta: x / (1 - eta),
         dt=1e-2,
         inner_steps=20,
         outer_steps=5,
         initial_state=np.arange(3.0),
         closed_form=lambda x0, t: np.arange(3) * jnp.exp(t),
         tolerances=[5e-2, 5e-5, 1e-5, 1e-5, 3e-5]),
    # x(t) = np.arange(3) * np.exp(t)
    # Splits the equation into an implicit and explicit term.
    dict(testcase_name='_linear_derivative_semi_implicit',
         explicit_terms=lambda x: x / 2,
         implicit_terms=lambda x: x / 2,
         implicit_solve=lambda x, eta: x / (1 - eta / 2),
         dt=1e-2,
         inner_steps=20,
         outer_steps=5,
         initial_state=np.arange(3) * np.exp(0),
         closed_form=lambda x0, t: np.arange(3.0) * jnp.exp(t),
         tolerances=[1e-4, 2e-5, 2e-6, 1e-6, 2e-5]),
    # dx[0]/dt = x[1], dx[1]/dt = -x[0]; reference is `harmonic_oscillator`.
    # Uses explicit terms only.
    dict(testcase_name='_harmonic_oscillator_explicit',
         explicit_terms=lambda x: jnp.stack([x[1], -x[0]]),
         implicit_terms=jnp.zeros_like,
         implicit_solve=lambda x, eta: x,
         dt=1e-2,
         inner_steps=20,
         outer_steps=5,
         initial_state=np.ones(2),
         closed_form=harmonic_oscillator,
         tolerances=[1e-2, 3e-5, 6e-8, 5e-11, 6e-8]),
    # dx[0]/dt = x[1], dx[1]/dt = -x[0]; reference is `harmonic_oscillator`.
    # Uses implicit terms only. `implicit_solve` is the solution of the 2x2
    # system y[0] - eta * y[1] = x[0], y[1] + eta * y[0] = x[1].
    dict(testcase_name='_harmonic_oscillator_implicit',
         explicit_terms=jnp.zeros_like,
         implicit_terms=lambda x: jnp.stack([x[1], -x[0]]),
         implicit_solve=lambda x, eta: jnp.stack(  # pylint: disable=g-long-lambda
             [x[0] + eta * x[1], x[1] - eta * x[0]]) / (1 + eta ** 2),
         dt=1e-2,
         inner_steps=20,
         outer_steps=5,
         initial_state=np.ones(2),
         closed_form=harmonic_oscillator,
         tolerances=[1e-2, 2e-5, 2e-6, 1e-6, 6e-6]),
]


# Time steppers exercised by the tests. The order matters: `test_integration`
# zips this list with each problem's `tolerances`, so entry i here is checked
# against `tolerances[i]`.
ALL_TIME_STEPPERS = [
    time_stepping.backward_forward_euler,
    time_stepping.crank_nicolson_rk2,
    time_stepping.crank_nicolson_rk3,
    time_stepping.crank_nicolson_rk4,
    time_stepping.imex_rk_sil3,
]


class TimeSteppingTest(parameterized.TestCase):
  """Tests the steppers in `ALL_TIME_STEPPERS` on `ALL_TEST_PROBLEMS`."""

  @parameterized.named_parameters(ALL_TEST_PROBLEMS)
  def test_implicit_solve(
      self,
      explicit_terms,
      implicit_terms,
      implicit_solve,
      dt,
      inner_steps,
      outer_steps,
      initial_state,
      closed_form,
      tolerances,
  ):
    """Tests that `implicit_solve` is consistent with `implicit_terms`."""
    del dt, explicit_terms, inner_steps, outer_steps, closed_form  # unused
    del tolerances  # unused

    # Verifies that `implicit_solve` solves (y - eta * F(y)) = x
    # This does not test the integrator, but rather verifies that the test
    # case is valid.
    eta = 0.3
    solved_state = implicit_solve(initial_state, eta)
    reconstructed_state = solved_state - eta * implicit_terms(solved_state)
    np.testing.assert_allclose(reconstructed_state, initial_state)

  @parameterized.named_parameters(ALL_TEST_PROBLEMS)
  def test_integration(
      self,
      explicit_terms,
      implicit_terms,
      implicit_solve,
      dt,
      inner_steps,
      outer_steps,
      initial_state,
      closed_form,
      tolerances,
  ):
    """Tests that time integration is accurate for a range of test cases."""
    # `zip` below stops at the shorter sequence, so a missing tolerance would
    # silently leave a time stepper untested. Fail loudly instead.
    self.assertEqual(
        len(tolerances), len(ALL_TIME_STEPPERS),
        'Each test problem needs exactly one tolerance per time stepper.')

    # Compute closed-form solution at the end of every outer step.
    time = dt * inner_steps * (1 + np.arange(outer_steps))
    expected = jax.vmap(closed_form, in_axes=(None, 0))(
        initial_state, time)

    # Compute trajectory using time-stepper.
    for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
      with self.subTest(time_stepper.__name__):
        equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
        semi_implicit_step = time_stepper(equation, dt)
        integrator = funcutils.trajectory(
            funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
        _, actual = integrator(initial_state)
        # assert_allclose takes (actual, desired); with rtol=0 the check is
        # symmetric, but this order labels the failure message correctly.
        np.testing.assert_allclose(actual, expected, atol=atol, rtol=0)

  def test_pytree_state(self):
    """Tests that a dict-valued state is accepted and left unchanged.

    All terms are zero and `implicit_solve` is the identity, so one step of
    any time stepper must return a state equal to the input.
    """
    equation = CustomODE(
        explicit_terms=lambda x: tree_util.tree_map(jnp.zeros_like, x),
        implicit_terms=lambda x: tree_util.tree_map(jnp.zeros_like, x),
        implicit_solve=lambda x, eta: x,
    )
    u0 = {'x': 1.0, 'y': 1.0}
    for time_stepper in ALL_TIME_STEPPERS:
      with self.subTest(time_stepper.__name__):
        u1 = time_stepper(equation, 1.0)(u0)
        self.assertEqual(u0, u1)


if __name__ == '__main__':
  # Enable 64-bit floats in JAX before the tests start.
  # NOTE(review): the tightest tolerances in ALL_TEST_PROBLEMS (1e-12)
  # presumably rely on this; the flag is only set when this file is run as a
  # script, not when the module is imported by another test runner -- confirm.
  config.update('jax_enable_x64', True)
  absltest.main()