funcutils.py 4.05 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX utility functions for JAX-CFD."""

import contextlib
from typing import Any, Callable, Sequence

import jax
from jax import tree_util
import jax.numpy as jnp


# Specifying the full signatures of Callable would get somewhat onerous
# pylint: disable=g-bare-generic
# Not accurate for contextmanager
# pylint: disable=g-doc-return-or-yield

# Type alias used to annotate a jax "pytree" with arrays at its leaves. There
# is currently no good way to express this precisely, so it is an alias for
# Any. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any


# Number of init_context() blocks that are currently active (they may be
# nested). scan() checks this counter and evaluates f() only once while it is
# nonzero.
_INITIALIZING = 0


@contextlib.contextmanager
def init_context():
  """Context manager under which scan() evaluates f() a single time.

  Intended for initializing Haiku neural nets whose modules are applied inside
  scan(): while the context is active those modules are called only once. This
  preserves the pre-omnistaging behavior of JAX, e.g., so a neural net module
  can be initialized by passing it directly into a scanned function.

  Contexts may be nested; the counter is restored on exit even if the body
  raises.
  """
  global _INITIALIZING
  _INITIALIZING = _INITIALIZING + 1
  try:
    yield
  finally:
    # Always undo the increment so an exception cannot leave scan() stuck in
    # initialization mode.
    _INITIALIZING = _INITIALIZING - 1


def _tree_stack(trees: Sequence[PyTree]) -> PyTree:
  """Stacks corresponding leaves of `trees` along a new leading axis.

  Args:
    trees: sequence of pytrees to combine leaf by leaf.

  Returns:
    A pytree whose leaves are the jnp.stack() of the matching leaves of every
    entry in `trees`. An empty `trees` is returned unchanged.
  """
  if not trees:
    # Nothing to stack: hand the empty sequence back as-is.
    return trees

  def stack_leaves(*leaves):
    return jnp.stack(leaves)

  return tree_util.tree_map(stack_leaves, *trees)


def scan(f, init, xs, length=None):
  """A version of jax.lax.scan that supports init_context().

  Outside of init_context() this is exactly jax.lax.scan(f, init, xs, length).
  Inside init_context(), f() is called only once, on the first slice of `xs`
  along the leading axis, and its output is repeated `length` times to build a
  dummy stacked output.

  Args:
    f: function mapping (carry, x) to (new_carry, y), as for jax.lax.scan.
    init: initial value of the carry.
    xs: pytree of arrays that are sliced along their leading axis, or None.
    length: optional number of steps. Inside init_context() it is inferred
      from the leading dimension of the leaves of `xs` when not given.

  Returns:
    A tuple (carry, ys). Inside init_context(), `carry` is the carry returned
    by the single call of f() and `ys` is that call's output stacked `length`
    times.

  Raises:
    ValueError: inside init_context(), if `length` is None and it cannot be
      inferred because `xs` has no leaves or its leaves disagree on the size
      of their leading dimension.
  """
  # Note: we use our own version of scan rather than haiku.scan() because
  # haiku.scan() only support use inside haiku modules, but we want to be able
  # to use the same scan function even when not using haiku.
  if not _INITIALIZING:
    return jax.lax.scan(f, init, xs, length)

  xs_flat, treedef = tree_util.tree_flatten(xs)
  if length is None:
    leading_sizes = {x.shape[0] for x in xs_flat}
    if not leading_sizes:
      raise ValueError(
          'scan() cannot infer the number of steps: `length` is None and '
          '`xs` contains no arrays.')
    if len(leading_sizes) > 1:
      raise ValueError(
          'scan() got `xs` whose leaves have inconsistent leading '
          f'dimensions: {sorted(leading_sizes)}.')
    (length,) = leading_sizes
  # First slice of every leaf, reassembled into the structure of `xs`.
  x0 = tree_util.tree_unflatten(treedef, [x[0, ...] for x in xs_flat])
  carry, y0 = f(init, x0)
  # Create a dummy-output of the right shape while only calling f() once.
  ys = _tree_stack(length * [y0])
  return carry, ys


def repeated(f: Callable, steps: int) -> Callable:
  """Builds a function that applies f() `steps` times via scan().

  Args:
    f: function that takes a state and returns a new state.
    steps: number of scan iterations.

  Returns:
    A function that takes an initial state and returns the final carry of
    scanning f() over `steps` iterations.
  """
  def f_repeated(x_initial):
    def body(state, _):
      # No per-step output is collected; only the carry is of interest.
      return f(state), None

    final_state, _ = scan(body, x_initial, xs=None, length=steps)
    return final_state

  return f_repeated


def _identity(x):
  """Returns `x` unchanged; the default no-op `post_process` of trajectory()."""
  return x


def trajectory(
    step_fn: Callable,
    steps: int,
    post_process: Callable = _identity,
    *,
    start_with_input: bool = False,
) -> Callable:
  """Builds a function that records repeated applications of `step_fn`.

  Args:
    step_fn: function mapping a state to the state one time step later.
    steps: number of steps (and therefore frames) in the trajectory.
    post_process: transformation applied to every recorded frame.
    start_with_input: if True, frames are recorded at steps [0, ..., steps-1],
      i.e., each frame is the state _before_ `step_fn` is applied; otherwise
      frames are recorded at steps [1, ..., steps].

  Returns:
    A function that takes an initial state and returns a tuple consisting of:
      (1) the final state of the trajectory, without `post_process` applied.
      (2) the trajectory of `steps` post-processed frames.
  """
  # TODO(shoyer): change the default to start_with_input=True, once we're sure
  # it works for training.
  def step(state, _):
    next_state = step_fn(state)
    # Choose which side of the step gets recorded as this iteration's frame.
    if start_with_input:
      frame = post_process(state)
    else:
      frame = post_process(next_state)
    return next_state, frame

  def multistep(values):
    return scan(step, values, xs=None, length=steps)

  return multistep