towers.py 11.1 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""Definitions of towers (neural networks based on multiple CNN layers)."""

import functools
from typing import Any, Callable, List, Optional, Tuple, Union
import gin
import jax
import haiku as hk

from jax_cfd.ml import layers
from jax_cfd.ml import nonlinearities

# Array type re-exported from `layers`; used in the annotations below.
Array = layers.Array
# Factory for a convolution module. The towers in this file call it as
# `conv_module(output_channels, kernel_shape, ndim, rate=..., stride=...)`.
ConvModule = Callable[..., Any]
# Scaling function: receives an array and the list of spatial axes and returns
# an array.
ScaleFn = Callable[[Array, List[int]], Array]
# Factory for a tower. `residual_block_tower_factory` calls it as
# `block_factory(num_output_channels, ndim)`.
TowerFactory = Callable[..., Any]


# Periodic convolution module classes from `layers`, keyed by `ndim`.
# `periodic_convolution` looks up the class to instantiate here.
PERIODIC_CONV_MODULES = {
    1: layers.PeriodicConv1D,
    2: layers.PeriodicConv2D,
    3: layers.PeriodicConv3D}

# Periodic transpose-convolution module classes from `layers`, keyed by `ndim`.
# `periodic_transpose_convolution` looks up the class to instantiate here.
PERIODIC_CONV_TRANSPOSE_MODULES = {
    1: layers.PeriodicConvTranspose1D,
    2: layers.PeriodicConvTranspose2D,
    3: layers.PeriodicConvTranspose3D}


@gin.register
def periodic_convolution(
    output_channels: int,
    kernel_shape: Tuple[int, ...],
    ndim: int,
    **kwargs
):
  """Builds the periodic convolution module registered for `ndim`.

  Args:
    output_channels: passed as the first positional argument of the module.
    kernel_shape: passed as the second positional argument of the module.
    ndim: key used to select the module class from `PERIODIC_CONV_MODULES`
      (entries exist for 1, 2 and 3).
    **kwargs: forwarded unchanged to the module constructor.

  Returns:
    An instance of the module class selected by `ndim`.

  Raises:
    KeyError: if `ndim` has no entry in `PERIODIC_CONV_MODULES`.
  """
  conv_cls = PERIODIC_CONV_MODULES[ndim]
  return conv_cls(output_channels, kernel_shape, **kwargs)


@gin.register
def periodic_transpose_convolution(
    output_channels: int,
    kernel_shape: Tuple[int, ...],
    ndim: int,
    rate: Optional[int] = None,
    **kwargs
):
  """Builds the periodic transpose convolution module registered for `ndim`.

  Args:
    output_channels: passed as the first positional argument of the module.
    kernel_shape: passed as the second positional argument of the module.
    ndim: key used to select the module class from
      `PERIODIC_CONV_TRANSPOSE_MODULES` (entries exist for 1, 2 and 3).
    rate: dilation rate. Only `None` or `1` is accepted; the value is never
      forwarded to the module.
    **kwargs: forwarded unchanged to the module constructor.

  Returns:
    An instance of the module class selected by `ndim`.

  Raises:
    ValueError: if `rate` is neither `None` nor equal to 1.
    KeyError: if `ndim` has no entry in `PERIODIC_CONV_TRANSPOSE_MODULES`.
  """
  dilation_requested = rate is not None and rate != 1
  if dilation_requested:
    raise ValueError('transpose convolutions do not support dilation rate')
  conv_transpose_cls = PERIODIC_CONV_TRANSPOSE_MODULES[ndim]
  return conv_transpose_cls(output_channels, kernel_shape, **kwargs)


@gin.register
def mirror_convolution(
    output_channels: int,
    kernel_shape: Tuple[int, ...],
    ndim: int,
    **kwargs
):
  """Builds a `layers.MirrorConv2D` module.

  Args:
    output_channels: passed as the first positional argument of the module.
    kernel_shape: passed as the second positional argument of the module.
    ndim: ignored; a `MirrorConv2D` is constructed whatever value is given.
    **kwargs: forwarded unchanged to the module constructor.

  Returns:
    The constructed `layers.MirrorConv2D` instance.
  """
  del ndim  # Present in the conv-module call signature but not used here.
  mirror_conv = layers.MirrorConv2D(output_channels, kernel_shape, **kwargs)
  return mirror_conv


@gin.register
def fixed_scale(inputs: Array,
                axes: Tuple[int, ...],
                rescaled_one: float = gin.REQUIRED) -> Array:
  """Multiplies `inputs` by the constant `rescaled_one`.

  An input value of `1` therefore maps to `rescaled_one`.

  Args:
    inputs: values to scale.
    axes: ignored; present because scale functions in this file are called
      with an `axes` argument.
    rescaled_one: multiplicative factor. Must be supplied (gin-required).

  Returns:
    The product `inputs * rescaled_one`.
  """
  del axes  # Not needed for a constant factor.
  scaled = inputs * rescaled_one
  return scaled


@gin.register
def fixed_scale_gridvar(
    inputs: Array,
    axes: Tuple[int, ...],
    rescaled_one: float = gin.REQUIRED
) -> Array:
  """Scales every element of `inputs` by `rescaled_one` and re-imposes its BC.

  Each element of `inputs` is expected to expose an `array` attribute and a
  `bc` attribute with an `impose_bc` method. The element's array is multiplied
  by `rescaled_one` and the product is passed to the element's own
  `bc.impose_bc`.

  Args:
    inputs: iterable of elements to scale.
    axes: ignored.
    rescaled_one: multiplicative factor. Must be supplied (gin-required).

  Returns:
    A tuple holding the result of `impose_bc` for each element, in input order.
  """
  del axes  # Not needed for a constant factor.
  rescaled = []
  for variable in inputs:
    scaled_array = variable.array * rescaled_one
    rescaled.append(variable.bc.impose_bc(scaled_array))
  return tuple(rescaled)  # pytype: disable=bad-return-type  # jax-devicearray


@gin.register
def scale_to_range(
    inputs: Array,
    axes: Tuple[int, ...],
    min_value: float = gin.REQUIRED,
    max_value: float = gin.REQUIRED,
) -> Array:
  """Rescales `inputs` into the `[min_value, max_value]` range.

  Thin wrapper that delegates to `layers.rescale_to_range`; see that function
  for the exact shift-and-scale transform applied on every `axes` slice.

  Args:
    inputs: array values to be rescaled.
    axes: axes over which the scaling is calculated.
    min_value: lower end of the target range. Must be supplied (gin-required).
    max_value: upper end of the target range. Must be supplied (gin-required).

  Returns:
    The value returned by `layers.rescale_to_range`.
  """
  rescaled = layers.rescale_to_range(inputs, min_value, max_value, axes)
  return rescaled


@gin.register
class MlpTowerFactory(hk.Module):
  """Tower that applies one shared MLP to inputs over spatial dimensions.

  The MLP is wrapped in `hk.vmap` once per spatial dimension, and the input
  and output scaling functions are bound to the list of spatial axes.
  """

  def __init__(
      self,
      output_size: int,
      ndim: int,
      num_hidden_units: int,
      num_hidden_layers: int,
      nonlinearity: Callable[[Array], Array] = nonlinearities.relu,
      inputs_scale_fn: ScaleFn = lambda x, axes: x,
      output_scale_fn: ScaleFn = lambda x, axes: x,
      name: Optional[str] = 'mlp_tower_factory',
  ):
    """Constructs the tower.

    Args:
      output_size: width of the final MLP layer.
      ndim: number of spatial dimensions; the MLP is vmapped this many times.
      num_hidden_units: width of every hidden layer.
      num_hidden_layers: number of hidden layers.
      nonlinearity: activation passed to `hk.nets.MLP`.
      inputs_scale_fn: scaling function applied to the inputs; called with
        `axes=list(range(ndim))`.
      output_scale_fn: scaling function applied to the outputs; called with
        `axes=list(range(ndim))`.
      name: name of this Haiku module.
    """
    super().__init__(name=name)
    dims = list(range(ndim))
    self.inputs_scale_fn = functools.partial(inputs_scale_fn, axes=dims)
    self.output_scale_fn = functools.partial(output_scale_fn, axes=dims)

    # Hidden layers all share one width; the final layer has `output_size`.
    layer_widths = [num_hidden_units for _ in range(num_hidden_layers)]
    layer_widths.append(output_size)
    tower = hk.nets.MLP(layer_widths, activation=nonlinearity)
    # One vmap per spatial dimension, so the same MLP is applied across them.
    for _ in range(ndim):
      tower = hk.vmap(tower, split_rng=False)
    self.mlp_tower = tower

  def __call__(self, inputs):
    """Scales `inputs`, applies the vmapped MLP, then scales the result."""
    scaled_inputs = self.inputs_scale_fn(inputs)
    tower_outputs = self.mlp_tower(scaled_inputs)
    return self.output_scale_fn(tower_outputs)


@gin.register
def forward_tower_factory(
    num_output_channels: int,
    ndim: int,
    num_hidden_channels: int = 16,
    kernel_size: int = 3,
    num_hidden_layers: int = 2,
    rates: Union[int, Tuple[int, ...]] = 1,
    strides: Union[int, Tuple[int, ...]] = 1,
    output_kernel_size: int = 3,
    output_dilation_rate: int = 1,
    output_stride: int = 1,
    conv_module: ConvModule = periodic_convolution,
    nonlinearity: Callable[[Array], Array] = nonlinearities.relu,
    inputs_scale_fn: ScaleFn = lambda x, axes: x,
    output_scale_fn: ScaleFn = lambda x, axes: x,
    name: Optional[str] = 'forward_cnn_tower',
):
  """Constructs a feed-forward CNN tower with uniform hidden layers.

  Every hidden layer uses `num_hidden_channels` channels and a kernel of
  `kernel_size` along each of the `ndim` dimensions. The per-layer settings
  are expanded here and construction is delegated to
  `forward_flex_tower_factory`.

  Args:
    num_output_channels: number of channels in the output layer.
    ndim: number of spatial dimensions to expect in inputs to the network.
    num_hidden_channels: number of channels to use in hidden layers.
    kernel_size: size of the kernel to use along every dimension.
    num_hidden_layers: number of hidden layers to construct in the tower.
    rates: dilation rate(s) of the hidden layers; forwarded unchanged.
    strides: stride(s) of the hidden layers; forwarded unchanged. Either an
      `int` or one entry per hidden layer.
    output_kernel_size: size of the output kernel to use along every dimension.
    output_dilation_rate: dilation rate of the output layer.
    output_stride: stride of the final convolution.
    conv_module: convolution module to use. Must accept
      (output channels, kernel shape and ndim).
    nonlinearity: nonlinearity function to apply between hidden layers.
    inputs_scale_fn: scaling function to be applied to the inputs of the tower.
      Must take inputs as argument and return an `Array` of the same shape.
      Can expect an `axes` argument specifying spatial axes in inputs.
    output_scale_fn: similar to `inputs_scale_fn` but applied to outputs.
    name: a name for this CNN tower. This name will appear in Xprof traces.

  Returns:
    CNN tower with specified configuration.
  """
  hidden_kernel_shape = (kernel_size,) * ndim
  hidden_kernel_shapes = (hidden_kernel_shape,) * num_hidden_layers
  hidden_channels = (num_hidden_channels,) * num_hidden_layers
  final_kernel_shape = (output_kernel_size,) * ndim
  return forward_flex_tower_factory(
      num_output_channels=num_output_channels,
      ndim=ndim,
      channels=hidden_channels,
      kernel_shapes=hidden_kernel_shapes,
      rates=rates,
      strides=strides,
      output_kernel_shape=final_kernel_shape,
      output_rate=output_dilation_rate,
      output_stride=output_stride,
      conv_module=conv_module,
      nonlinearity=nonlinearity,
      inputs_scale_fn=inputs_scale_fn,
      output_scale_fn=output_scale_fn,
      name=name,
  )


@gin.register
def forward_flex_tower_factory(
    num_output_channels: int,
    ndim: int,
    channels: Tuple[int, ...] = (16, 16),
    kernel_shapes: Tuple[Tuple[int, ...], ...] = ((3, 3), (3, 3)),
    rates: Tuple[int, ...] = (1, 1),
    strides: Tuple[int, ...] = (1, 1),
    output_kernel_shape: Tuple[int, ...] = (3, 3),
    output_rate: int = 1,
    output_stride: int = 1,
    conv_module: ConvModule = periodic_convolution,
    nonlinearity: Callable[[Array], Array] = nonlinearities.relu,
    inputs_scale_fn: ScaleFn = lambda x, axes: x,
    output_scale_fn: ScaleFn = lambda x, axes: x,
    name: Optional[str] = 'forward_flex_cnn_tower',
):
  """Constructs a CNN tower from per-layer specifications.

  The tower applies, in order: the input scaling, one convolution followed by
  `nonlinearity` for every entry of `channels`, a final convolution, and the
  output scaling.

  Args:
    num_output_channels: number of channels in the output layer.
    ndim: number of spatial dimensions to expect in inputs to the network.
    channels: tuple specifying number of channels in hidden layers.
    kernel_shapes: tuple specifying shapes of kernels in hidden layers.
      Each entry must be a tuple that specifies a valid kernel_shape for the
      provided `conv_module`. Must have the same length as `channels`.
    rates: dilation rates of the hidden convolutions. A single `int` is
      repeated for every hidden layer.
    strides: strides of the hidden convolutions. A single `int` is repeated
      for every hidden layer.
    output_kernel_shape: shape of the output kernel.
    output_rate: dilation rate of the final convolution.
    output_stride: stride of the final convolution.
    conv_module: convolution module to use. Called as
      `conv_module(channels, kernel_shape, ndim, rate=..., stride=...)`.
    nonlinearity: nonlinearity function to apply after every hidden layer.
    inputs_scale_fn: scaling function to be applied to the inputs of the tower.
      Must take `inputs`, `axes` arguments specifying input `Array` and
      spatial dimensions and return an `Array` of the same shape as `inputs`.
    output_scale_fn: similar to `inputs_scale_fn` but applied to outputs.
    name: a name for this CNN tower. This name will appear in Xprof traces.

  Returns:
    CNN tower with specified architecture.

  Raises:
    ValueError: if `kernel_shapes`, `rates` and `strides` do not all have the
      same length as `channels`.
  """
  # Expand scalar settings to one entry per hidden layer.
  if isinstance(rates, int):
    rates = (rates,) * len(channels)
  if isinstance(strides, int):
    strides = (strides,) * len(channels)

  spatial_axes = list(range(ndim))
  num_hidden = len(channels)
  per_layer_settings = (kernel_shapes, rates, strides)
  if any(len(setting) != num_hidden for setting in per_layer_settings):
    raise ValueError('conflicting lengths for channels/kernels/rates/strides: '
                     f'{channels} / {kernel_shapes} / {rates} / {strides}')

  # NOTE: the inner function must keep the name `forward_pass`, because
  # `hk.to_module` falls back to the function name when `name` is None.
  def forward_pass(inputs):
    stages = [functools.partial(inputs_scale_fn, axes=spatial_axes)]
    # Modules are created in layer order: hidden convolutions first, then the
    # output convolution.
    for layer_channels, layer_kernel, layer_rate, layer_stride in zip(
        channels, kernel_shapes, rates, strides):
      hidden_conv = conv_module(
          layer_channels, layer_kernel, ndim,
          rate=layer_rate, stride=layer_stride)
      stages += [hidden_conv, nonlinearity]
    output_conv = conv_module(
        num_output_channels, output_kernel_shape, ndim,
        rate=output_rate, stride=output_stride)
    stages += [
        output_conv,
        functools.partial(output_scale_fn, axes=spatial_axes),
    ]
    return hk.Sequential(stages)(inputs)

  tower = hk.to_module(forward_pass)(name=name)
  return jax.named_call(tower, name=name)


@gin.register
def residual_block_tower_factory(
    num_output_channels: int,
    ndim: int,
    num_blocks: int = 2,
    block_factory: TowerFactory = forward_tower_factory,
    skip_connection_fn: Callable[..., Array] = lambda x, block_num: x,
    inputs_scale_fn: ScaleFn = lambda x, axes: x,
    output_scale_fn: ScaleFn = lambda x, axes: x,
    name: Optional[str] = 'residual_block_tower',
):
  """Constructs a tower with skip connections between blocks.

  The first `num_blocks - 1` blocks each add their output to a skip value
  computed from the block's input. A final block, without a skip connection,
  produces `num_output_channels` channels.

  Args:
    num_output_channels: number of channels requested from the final block.
    ndim: number of spatial dimensions; passed to `block_factory` and used to
      build the `axes` list given to the scaling functions.
    num_blocks: total number of blocks, including the final one.
    block_factory: called as `block_factory(num_channels, ndim)` to build each
      block.
    skip_connection_fn: called as `skip_connection_fn(inputs, block_num)` to
      produce the skip value. Its last dimension sets the number of output
      channels of the corresponding block.
    inputs_scale_fn: scaling function applied to the tower inputs.
    output_scale_fn: scaling function applied to the final block's outputs.
    name: a name for this tower, used for the Haiku module and `named_call`.

  Returns:
    The tower module wrapped in `jax.named_call`.
  """
  # NOTE: the inner function must keep the name `forward_pass`, because
  # `hk.to_module` falls back to the function name when `name` is None.
  def forward_pass(inputs):
    hidden = inputs_scale_fn(inputs, list(range(ndim)))
    for block_index in range(num_blocks - 1):
      shortcut = skip_connection_fn(hidden, block_index)
      # The block must emit as many channels as the skip value it is added to.
      residual_block = block_factory(shortcut.shape[-1], ndim)
      hidden = shortcut + residual_block(hidden)
    final_block = block_factory(num_output_channels, ndim)
    final_outputs = final_block(hidden)
    return output_scale_fn(final_outputs, list(range(ndim)))

  tower = hk.to_module(forward_pass)(name=name)
  return jax.named_call(tower, name=name)


@gin.register
def residual_connection(*args, module_factory, **kwargs):
  """Applies `module_factory()` as a residual correction to inputs.

  Args:
    *args: positional arguments forwarded to `module_factory`.
    module_factory: callable that builds the module computing the correction.
    **kwargs: keyword arguments forwarded to `module_factory`.

  Returns:
    A Haiku module that maps `inputs` to
    `inputs + module_factory(*args, **kwargs)(inputs)`.
  """
  # NOTE: the inner function must keep the name `forward_pass`; no explicit
  # name is passed below, so `hk.to_module` names the module after it.
  def forward_pass(inputs):
    correction_module = module_factory(*args, **kwargs)
    correction = correction_module(inputs)
    return inputs + correction

  residual_module_cls = hk.to_module(forward_pass)
  return residual_module_cls()