encoders.py 8.75 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""Encoder modules that help interfacing input trajectories to model states.

All encoder modules generate a function that given an input trajectory infers
the final state of the physical system in the representation defined by the
Encoder. Encoders can be either fixed functions, decorators or learned modules.
The input state is expected to consist of arrays with `time` as a leading axis.
"""

from typing import Any, Callable, Optional, Tuple
import gin
import haiku as hk
import jax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers


EncodeFn = Callable[[Any], Any]  # maps input trajectory to final model state.
EncoderModule = Callable[..., EncodeFn]  # generate EncodeFn closed over args.
TowerFactory = towers.TowerFactory


@gin.register
def aligned_array_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
  """Generates encoder that wraps last data slice as GridVariables."""
  del dt  # unused.
  if hasattr(physics_specs, 'combo_offsets'):
    data_offsets = physics_specs.combo_offsets()
  else:
    data_offsets = data_offsets or grid.cell_faces
  slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)

  def encode_fn(inputs):
    if hasattr(physics_specs, 'combo_boundaries'):
      bcs = physics_specs.combo_boundaries()
    else:
      bcs = tuple(
          boundaries.periodic_boundary_conditions(grid.ndim)
          for _ in range(len(inputs)))
    return tuple(
        bc.impose_bc(grids.GridArray(slice_last_fn(x), offset, grid))
        for x, offset, bc in zip(inputs, data_offsets, bcs))

  return encode_fn


@gin.register
def collocated_to_staggered_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> EncodeFn:
  """Encoder that interpolates from collocated to staggered grids."""
  del dt, physics_specs  # unused.
  slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)

  def encode_fn(inputs):
    bc = boundaries.periodic_boundary_conditions(grid.ndim)
    src_offset = grid.cell_center
    pre_interp = tuple(
        grids.GridVariable(
            grids.GridArray(slice_last_fn(x), src_offset, grid), bc)
        for x in inputs)
    return tuple(interpolation.linear(c, offset)
                 for c, offset in zip(pre_interp, grid.cell_faces))

  return encode_fn


@gin.register
def slice_last_state_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    time_axis=0,
) -> EncodeFn:
  """Generates encoder that returns last data slice along time axis."""
  del grid, dt, physics_specs  # unused.
  def encode_fn(inputs):
    return array_utils.slice_along_axis(inputs, time_axis, -1)
  return encode_fn


@gin.register
def slice_last_n_state_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    n: int = gin.REQUIRED,
    time_axis: int = 0,
) -> EncodeFn:
  """Generates encoder that returns last `n` data slices along last axis."""
  del grid, dt, physics_specs  # unused.
  def encode_fn(inputs):
    init_slice = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
    return jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), init_slice)
  return encode_fn


@gin.register
def stack_last_n_state_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    n: int = gin.REQUIRED,
    time_axis: int = 0,
) -> EncodeFn:
  """Generates encoder that stacks last `n` inputs slices along last axis."""
  del grid, dt, physics_specs  # unused.
  def encode_fn(inputs):
    inputs = array_utils.slice_along_axis(inputs, 0, slice(-n, None))
    inputs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, time_axis, -1), inputs)
    return array_utils.concat_along_axis(jax.tree_util.leaves(inputs), axis=-1)

  return encode_fn


@gin.register
def latent_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    tower_factory: TowerFactory,
    num_latent_dims: int,
    n_frames: int,
    time_axis: int = 0,
):
  """Generates trainable encoder that maps inputs to a latent representation.

  Encoder first stacks last `n_frames` time slices in input trajectory along
  channels and then applies a network specified by a `tower_factory` to obtain
  a latent field representation with `num_latent_dims` channel dimensions.

  Args:
    grid: grid representing spatial discritization of the system.
    dt: time step to use for time evolution.
    physics_specs: physical parameters of the simulation.
    tower_factory: factory that produces trainable tower network module.
    num_latent_dims: number of channels to have in latent representation.
    n_frames: number of last frames in input trajectory to use for encoding.
    time_axis: axis in input trajectory that correspond to time.

  Returns:
    encode function that maps input trajectory `inputs` to a latent field
    representation with `num_latent_dims`. Note that depending on the tower used
    the spatial dimension of the representation might differ from `inputs`.
  """

  stack_inputs_fn = stack_last_n_state_encoder(
      grid, dt, physics_specs, n_frames, time_axis)

  def encode_fn(inputs):
    inputs = stack_inputs_fn(inputs)
    encoder_tower = tower_factory(num_latent_dims, grid.ndim, name='encoder')
    return encoder_tower(inputs)

  return hk.to_module(encode_fn)()


@gin.register
def aligned_latent_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    tower_factory: TowerFactory,
    num_latent_dims: int,
    n_frames: int,
    time_axis: int = 0,
    data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
):
  """Latent encoder that decodes to GridVariables."""
  data_offsets = data_offsets or grid.cell_faces
  stack_inputs_fn = stack_last_n_state_encoder(
      grid, dt, physics_specs, n_frames, time_axis)

  def encode_fn(inputs):
    bc = boundaries.periodic_boundary_conditions(grid.ndim)
    inputs = stack_inputs_fn(inputs)
    encoder_tower = tower_factory(num_latent_dims, grid.ndim, name='encoder')
    raw_outputs = encoder_tower(inputs)
    split_outputs = [raw_outputs[..., i] for i in range(raw_outputs.shape[-1])]
    return tuple(
        grids.GridVariable(grids.GridArray(x, offset, grid), bc)
        for x, offset in zip(split_outputs, data_offsets))

  return hk.to_module(encode_fn)()


@gin.register
def vorticity_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
  """Maps velocity to vorticity."""
  del dt, physics_specs, data_offsets  # unused.
  slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)

  def encode_fn(inputs):
    u, v = inputs
    u, v = slice_last_fn(u), slice_last_fn(v)
    uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
    kx, ky = grid.rfft_mesh()
    vorticity_hat = 2j * jnp.pi * (vhat * kx - uhat * ky)
    # TODO(dresdner) main difference is that the output is ifft'ed.
    # TODO(dresdner) and also that the output has a channel dim.
    return jnp.fft.irfft2(vorticity_hat)[..., jnp.newaxis]

  return encode_fn


@gin.register
def vorticity_velocity_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
  """Maps velocity to [velocity; vorticity]."""
  del dt, physics_specs, data_offsets  # unused.
  slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)
  ifft = jnp.fft.irfft2

  def encode_fn(inputs):
    u, v = inputs
    u, v = slice_last_fn(u), slice_last_fn(v)
    uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
    kx, ky = grid.rfft_mesh()
    vorticity_hat = 2j * jnp.pi * (vhat * kx - uhat * ky)
    return jnp.stack([u, v, ifft(vorticity_hat)], axis=-1)

  return encode_fn


@gin.register
def spectral_vorticity_encoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    data_offsets: Optional[Tuple[Tuple[float, ...], ...]] = None,
) -> EncodeFn:
  """Generates encoder that wraps last data slice as GridVariables."""
  del dt, physics_specs, data_offsets  # unused.
  slice_last_fn = lambda x: array_utils.slice_along_axis(x, 0, -1)

  def encode_fn(inputs):
    u, v = inputs
    u, v = slice_last_fn(u), slice_last_fn(v)
    uhat, vhat = jnp.fft.rfft2(u), jnp.fft.rfft2(v)
    kx, ky = grid.rfft_mesh()
    vorticity = 2j * jnp.pi * (vhat * kx - uhat * ky)
    return vorticity

  return encode_fn