decoders.py 5.35 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Decoder modules that help interfacing model states with output data.

All decoder modules generate a function that, given a specific model state,
returns the observable data of the same structure as provided to the Encoder.
Decoders can be either fixed functions, decorators, or learned modules.
"""

from typing import Any, Callable, Optional

import gin
import haiku as hk
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import grids
from jax_cfd.base import interpolation
from jax_cfd.ml import physics_specifications
from jax_cfd.ml import towers
from jax_cfd.spectral import utils as spectral_utils


# Type aliases used in the signatures of the decoder modules in this file.
DecodeFn = Callable[[Any], Any]  # maps model state to data time slice.
DecoderModule = Callable[..., DecodeFn]  # generate DecodeFn closed over args.
# Local alias for `towers.TowerFactory`; the latent decoders in this file call
# such a factory as `tower_factory(num_channels, grid.ndim, name='decoder')`.
TowerFactory = towers.TowerFactory


@gin.register
def identity_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Builds a no-op decoder.

  Args:
    grid: unused by this decoder.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Function that hands back the model state it receives, unmodified.
  """
  del grid, dt, physics_specs  # unused.

  def decode_fn(state):
    # Nothing to convert: the state already is the output.
    return state

  return decode_fn


# TODO(dkochkov) generalize this to arbitrary pytrees.
@gin.register
def aligned_array_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Builds a decoder that unwraps the `.data` attribute of each state entry.

  Args:
    grid: unused by this decoder.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Function that takes an iterable of objects exposing `.data` and returns a
    tuple holding those `.data` values in iteration order.
  """
  del grid, dt, physics_specs  # unused.

  def decode_fn(variables):
    arrays = [variable.data for variable in variables]
    return tuple(arrays)

  return decode_fn


@gin.register
def staggered_to_collocated_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Generates decoder that interpolates from staggered to collocated grids.

  Args:
    grid: grid whose `cell_center` is used as the interpolation target.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Decode function that applies `interpolation.linear` to every entry of
    `inputs`, targeting `grid.cell_center`, and returns the `.data` of the
    interpolated results as a tuple (in input order).
  """
  del dt, physics_specs  # unused.

  def decode_fn(inputs):
    return tuple(
        interpolation.linear(c, grid.cell_center).data for c in inputs)

  return decode_fn


@gin.register
def channels_split_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Builds a decoder that breaks the trailing axis apart into data entries.

  Args:
    grid: unused by this decoder.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Function that passes its argument to `array_utils.split_axis` with axis
    `-1` and returns the result.
  """
  del grid, dt, physics_specs  # unused.
  channel_axis = -1  # the split always happens along the last axis.

  def decode_fn(stacked):
    return array_utils.split_axis(stacked, channel_axis)

  return decode_fn


@gin.register
def latent_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    tower_factory: TowerFactory,
    num_components: Optional[int] = None,
) -> DecodeFn:
  """Generates trainable decoder that maps latent representation to data tuple.

  Decoder first computes an array of outputs using network specified by a
  `tower_factory` and then splits the channels into `num_components` components.

  Args:
    grid: grid representing spatial discretization of the system.
    dt: time step; only forwarded to `channels_split_decoder`.
    physics_specs: physical parameters of the simulation; only forwarded to
      `channels_split_decoder`.
    tower_factory: factory that produces trainable tower network module. It is
      called as `tower_factory(num_channels, grid.ndim, name='decoder')`.
    num_components: number of data tuples in the data representation of the
      state. If None (or any other falsy value, e.g. 0), `grid.ndim` is used.
      Default is None.

  Returns:
    decode function that maps latent state `inputs` at given time to a tuple of
    `num_components` data arrays representing the same state at the same time.
  """
  split_channels_fn = channels_split_decoder(grid, dt, physics_specs)
  # Does not depend on the inputs, so resolve it once up front.
  num_channels = num_components or grid.ndim

  # The tower is built inside `decode_fn` because that function is what gets
  # wrapped into a Haiku module below.
  # NOTE(review): Haiku presumably derives the module name from the function
  # name, so keep it as `decode_fn` -- confirm before renaming.
  def decode_fn(inputs):
    decoder_tower = tower_factory(num_channels, grid.ndim, name='decoder')
    return split_channels_fn(decoder_tower(inputs))

  return hk.to_module(decode_fn)()


@gin.register
def aligned_latent_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
    tower_factory: TowerFactory,
    num_components: Optional[int] = None,
) -> DecodeFn:
  """Latent decoder that decodes from aligned arrays.

  The `.data` of every entry in `inputs` is stacked along a new trailing axis,
  passed through the network built by `tower_factory`, and the result is split
  along its last axis by `channels_split_decoder`.

  Args:
    grid: grid whose `ndim` is passed to `tower_factory` and used as the
      default number of output channels.
    dt: time step; only forwarded to `channels_split_decoder`.
    physics_specs: physical parameters of the simulation; only forwarded to
      `channels_split_decoder`.
    tower_factory: factory that produces trainable tower network module. It is
      called as `tower_factory(num_channels, grid.ndim, name='decoder')`.
    num_components: number of output channels requested from the tower. If
      None (or any other falsy value, e.g. 0), `grid.ndim` is used. Default is
      None.

  Returns:
    decode function that maps an iterable of objects exposing `.data` to the
    output of the channel split applied to the tower's result.
  """
  split_channels_fn = channels_split_decoder(grid, dt, physics_specs)
  # Does not depend on the inputs, so resolve it once up front.
  num_channels = num_components or grid.ndim

  # The tower is built inside `decode_fn` because that function is what gets
  # wrapped into a Haiku module below.
  # NOTE(review): Haiku presumably derives the module name from the function
  # name, so keep it as `decode_fn` -- confirm before renaming.
  def decode_fn(inputs):
    # Stack the per-entry arrays along a new last axis before the tower.
    inputs = jnp.stack([x.data for x in inputs], axis=-1)
    decoder_tower = tower_factory(num_channels, grid.ndim, name='decoder')
    return split_channels_fn(decoder_tower(inputs))

  return hk.to_module(decode_fn)()


@gin.register
def vorticity_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Generates decoder that solves for velocity from real-space vorticity.

  Args:
    grid: grid passed to `spectral_utils.vorticity_to_velocity` to build the
      velocity solve.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Decode function that takes a real-space vorticity array with a trailing
    singleton channel axis and returns a tuple of two real-space arrays, the
    inverse real FFTs of the two outputs of the velocity solve.
  """
  del dt, physics_specs  # unused.
  velocity_solve = spectral_utils.vorticity_to_velocity(grid)

  def decode_fn(vorticity):
    # TODO(dresdner) note the main difference from `spectral_vorticity_decoder`
    # is the input, which is in real space instead of spectral space.
    vorticity = jnp.squeeze(vorticity, axis=-1)  # remove channel dim
    # `irfft2` cannot tell whether the last real-space axis originally had odd
    # or even length and defaults to even. Passing the original spatial shape
    # leaves even sizes unchanged and restores the right shape for odd ones
    # (assumes `velocity_solve` preserves the spectral shape -- TODO confirm).
    spatial_shape = vorticity.shape[-2:]
    vorticity_hat = jnp.fft.rfft2(vorticity)
    uhat, vhat = velocity_solve(vorticity_hat)
    v = (jnp.fft.irfft2(uhat, s=spatial_shape),
         jnp.fft.irfft2(vhat, s=spatial_shape))
    return v

  return decode_fn


@gin.register
def spectral_vorticity_decoder(
    grid: grids.Grid,
    dt: float,
    physics_specs: physics_specifications.BasePhysicsSpecs,
) -> DecodeFn:
  """Builds a decoder that turns spectral vorticity into real-space velocity.

  Args:
    grid: grid handed to `spectral_utils.vorticity_to_velocity` to construct
      the velocity solve.
    dt: unused by this decoder.
    physics_specs: unused by this decoder.

  Returns:
    Function that runs the velocity solve on `vorticity_hat` and returns a
    tuple with the inverse real 2D FFT of each of its two outputs.
  """
  del dt, physics_specs  # unused.
  solve_for_velocity = spectral_utils.vorticity_to_velocity(grid)

  def decode_fn(vorticity_hat):
    u_hat, v_hat = solve_for_velocity(vorticity_hat)
    to_real_space = jnp.fft.irfft2
    return (to_real_space(u_hat), to_real_space(v_hat))

  return decode_fn