time_integrators.py 977 Bytes
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""Methods for time integration of first order differential equations."""
import gin
import haiku as hk
import jax


# TODO(dkochkov) Include other integrators such as DP; RK methods;
# TODO(dkochkov) Add option to include input as in funcutils.trajectory.


@gin.register
def euler_integrator(
    derivative_module,
    initial_state,
    dt,
    num_steps,
):
  """Integrates ode defined by `derivative_module` using euler method.

  Args:
    derivative_module: hk.Module that computes time derivative.
    initial_state: initial state for time integration.
    dt: time step.
    num_steps: number time steps `dt` to integrate for.

  Returns:
   final state at time `t + num_steps * dt` and `dt` spaced trajectory.
  """
  def _single_step(state, _):
    deriv = derivative_module(state)
    next_state = jax.tree_util.tree_map(lambda x, dxdt: x + dt * dxdt, state, deriv)
    return next_state, next_state

  return hk.scan(_single_step, initial_state, None, num_steps)