layers.py 27.3 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
"""Custom neural-net layers for physics simulations."""

import functools
from typing import (
    Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union,
)

import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
from jax_cfd.base import array_utils
from jax_cfd.base import boundaries
from jax_cfd.base import grids
from jax_cfd.ml import layers_util
from jax_cfd.ml import tiling
import numpy as np
import scipy.linalg

Array = Union[np.ndarray, jax.Array]
IntOrSequence = Union[int, Sequence[int]]


class PeriodicConvGeneral(hk.Module):
  """Convolution module with periodic boundaries for any number of dimensions.

  The wrapped convolution runs with `VALID` padding. The per-axis padding
  widths required to keep the spatial shape unchanged are computed once in the
  constructor and handed, together with the optional tile layout, to
  `tiling.apply_convolution` on every call.
  """

  def __init__(
      self,
      base_convolution: Callable[..., Any],
      output_channels: int,
      kernel_shape: Tuple[int, ...],
      rate: int = 1,
      tile_layout: Optional[Tuple[int, ...]] = None,
      name: str = 'periodic_conv_general',
      **conv_kwargs: Any
  ):
    """Constructs PeriodicConvGeneral module.

    `base_convolution` is instantiated with `VALID` padding; the explicit
    padding widths computed here are what produce the desired periodic
    boundary behavior when the module is applied.

    Args:
      base_convolution: standard convolution module e.g. hk.Conv1D.
      output_channels: number of output channels.
      kernel_shape: shape of the kernel, compatible with `base_convolution`.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments passed to `base_convolution`.
    """
    super().__init__(name=name)
    # A kernel of size k dilated by `rate` spans k + (rate - 1) * (k - 1) cells.
    dilated_sizes = [k + (rate - 1) * (k - 1) for k in kernel_shape]
    # Each axis receives (span - 1) cells of padding in total, split so that
    # the left side gets span // 2 and the right side gets the remainder.
    self._padding = [
        (span // 2, span - span // 2 - 1) for span in dilated_sizes
    ]
    self._tile_layout = tile_layout
    self._conv_module = base_convolution(
        output_channels=output_channels,
        kernel_shape=kernel_shape,
        padding='VALID',
        rate=rate,
        **conv_kwargs)

  def __call__(self, inputs):
    """Applies the wrapped convolution to `inputs` with the stored padding."""
    conv, layout, padding = (
        self._conv_module, self._tile_layout, self._padding)
    return tiling.apply_convolution(conv, inputs, layout, padding)


class PeriodicConv1D(PeriodicConvGeneral):
  """`PeriodicConvGeneral` specialized to `hk.Conv1D`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int],
      rate: int = 1,
      tile_layout: Optional[Tuple[int]] = None,
      name='periodic_conv_1d',
      **conv_kwargs
  ):
    """Constructs PeriodicConv1D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 1D kernel.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv1D`.
    """
    # Everything is delegated to the general module; only the base
    # convolution class is fixed here.
    super().__init__(
        hk.Conv1D, output_channels, kernel_shape,
        rate=rate, tile_layout=tile_layout, name=name, **conv_kwargs)


class PeriodicConv2D(PeriodicConvGeneral):
  """`PeriodicConvGeneral` specialized to `hk.Conv2D`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int, int],
      rate: int = 1,
      tile_layout: Optional[Tuple[int, int]] = None,
      name='periodic_conv_2d',
      **conv_kwargs
  ):
    """Constructs PeriodicConv2D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 2D kernel.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv2D`.
    """
    # Everything is delegated to the general module; only the base
    # convolution class is fixed here.
    super().__init__(
        hk.Conv2D, output_channels, kernel_shape,
        rate=rate, tile_layout=tile_layout, name=name, **conv_kwargs)


class PeriodicConv3D(PeriodicConvGeneral):
  """`PeriodicConvGeneral` specialized to `hk.Conv3D`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int, int, int],
      rate: int = 1,
      tile_layout: Optional[Tuple[int, int, int]] = None,
      name='periodic_conv_3d',
      **conv_kwargs
  ):
    """Constructs PeriodicConv3D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 3D kernel.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv3D`.
    """
    # Everything is delegated to the general module; only the base
    # convolution class is fixed here.
    super().__init__(
        hk.Conv3D, output_channels, kernel_shape,
        rate=rate, tile_layout=tile_layout, name=name, **conv_kwargs)


class MirrorConvGeneral(hk.Module):
  """General mirror convolution module.

  Each input variable is padded through its own boundary condition using
  `boundaries.Padding.MIRROR`, the padded data of all variables are stacked
  along a trailing channel axis, and a `VALID` convolution is applied to the
  stack. Every output channel is then wrapped back into a grid array and the
  boundary condition of the matching input is imposed on it.
  """

  def __init__(self,
               base_convolution: Callable[..., Any],
               output_channels: int,
               kernel_shape: Tuple[int, ...],
               rate: int = 1,
               tile_layout: Optional[Tuple[int, ...]] = None,
               name: str = 'mirror_conv_general',
               **conv_kwargs: Any):
    """Constructs MirrorConvGeneral module.

    We use `VALID` padding on `base_convolution` and explicit padding beyond
    the boundaries. This function computes `paddings` and combines the padding
    calls with the `base_convolution` module.

    Args:
      base_convolution: standard convolution module e.g. hk.Conv1D.
      output_channels: number of output channels.
      kernel_shape: shape of the kernel, compatible with `base_convolution`.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
        NOTE(review): stored on the module but not read by `__call__` —
        confirm whether tiling is meant to be supported here.
      name: name of the module.
      **conv_kwargs: additional arguments passed to `base_convolution`.
    """
    super().__init__(name=name)
    # NOTE(review): the (left, right) widths are *summed* over all kernel
    # dimensions into one pair, and `_expand_var` applies that same pair to
    # every grid axis — confirm this accumulation is intended.
    self._padding = np.zeros(2).astype('int')
    for kernel_size in kernel_shape:
      # Extent of the kernel once dilated by `rate`.
      effective_kernel = kernel_size + (rate - 1) * (kernel_size - 1)
      pad_left = effective_kernel // 2
      self._padding += np.array([pad_left,
                                 effective_kernel - pad_left - 1]).astype('int')
    self._tile_layout = tile_layout
    self._conv_module = base_convolution(
        output_channels=output_channels, kernel_shape=kernel_shape,
        padding='VALID', rate=rate, **conv_kwargs)

  def _expand_var(self, var):
    """Mirror-pads `var` on every grid axis and appends a channel axis."""
    var = var.bc.pad_all(
        var, (self._padding,) * var.grid.ndim, mode=boundaries.Padding.MIRROR)
    return jnp.expand_dims(var.data, axis=-1)

  def __call__(self, inputs):
    """Applies the mirror-padded convolution to a sequence of variables.

    Args:
      inputs: sequence of variables; each must expose `bc`, `grid`, `offset`
        and `data` attributes.

    Returns:
      Tuple with one entry per (output channel, input variable) pair, each
      being the result of `bc.impose_bc` of the corresponding input applied to
      the convolved channel.
    """
    input_data = tuple(self._expand_var(var) for var in inputs)
    # `jax.tree_util.tree_leaves` is the public pytree-flattening function;
    # `jax.tree_util.leaves` does not exist and raised AttributeError here.
    input_data = array_utils.concat_along_axis(
        jax.tree_util.tree_leaves(input_data), axis=-1)
    outputs = self._conv_module(input_data)
    outputs = array_utils.split_axis(outputs, -1)
    # zip() pairs output channels with inputs positionally and stops at the
    # shorter of the two sequences.
    outputs = tuple(
        var_input.bc.impose_bc(
            grids.GridArray(var, var_input.offset, var_input.grid))
        for var, var_input in zip(outputs, inputs))
    return outputs


class MirrorConv2D(MirrorConvGeneral):
  """`MirrorConvGeneral` specialized to `hk.Conv2D`."""

  def __init__(self,
               output_channels: int,
               kernel_shape: Tuple[int, int],
               rate: int = 1,
               tile_layout: Optional[Tuple[int, int]] = None,
               name='mirror_conv_2d',
               **conv_kwargs):
    """Constructs MirrorConv2D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 2D kernel.
      rate: dilation rate of the convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv2D`.
    """
    # Only the base convolution class is fixed here; the rest is forwarded.
    super().__init__(
        hk.Conv2D, output_channels, kernel_shape,
        rate=rate, tile_layout=tile_layout, name=name, **conv_kwargs)


class PeriodicConvTransposeGeneral(hk.Module):
  """Transpose convolution module with periodic boundaries, any dimension."""

  def __init__(
      self,
      base_convolution: Callable[..., Any],
      output_channels: int,
      kernel_shape: Tuple[int, ...],
      stride: int = 1,
      tile_layout: Optional[Tuple[int, ...]] = None,
      name: str = 'periodic_conv_transpose_general',
      **conv_kwargs: Any
  ):
    """Constructs PeriodicConvTransposeGeneral module.

    Periodicity is obtained in three steps. Inputs are first padded at the
    start of every spatial axis, so that outputs generated by the original
    slice of the inputs also receive contributions from their periodic images.
    `base_convolution` is then applied with `VALID` padding and the result is
    sliced to discard the boundary values. Finally the output is rolled to
    undo the drift of the spatial axes: in the standard implementation of a
    transposed convolution the kernel applied at index [i] affects outputs at
    indices [i: i + kernel_size]; the center of that affected field is taken
    as the spatial location of the output, hence the shift back by half of the
    kernel size.

    Args:
      base_convolution: standard transpose convolution module.
      output_channels: number of output channels.
      kernel_shape: shape of the kernel, compatible with `base_convolution`.
      stride: stride to use in `base_convolution`.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
        Must currently be None.
      name: name of the module.
      **conv_kwargs: additional arguments passed to `base_convolution`.

    Raises:
      NotImplementedError: if `tile_layout` is not None.
    """
    if tile_layout is not None:
      raise NotImplementedError(
          "tile_layout doesn't work yet for transpose convolutions")
    super().__init__(name=name)
    self._stride = stride
    self._kernel_shape = kernel_shape
    # Pad on the left only, by kernel_size // stride + 1 input cells, which
    # makes stride * pad_left exceed kernel_size. Per the original design
    # note, the left pad has to be large enough for the output of the first
    # original index to collect the contributions of the padded elements.
    self._padding = [(size // stride + 1, 0) for size in kernel_shape]
    # Shift back by half a kernel at the end to recover spatial alignment.
    self._roll_shifts = [-((size - 1) // 2) for size in kernel_shape]
    self._tile_layout = tile_layout
    self._conv_module = base_convolution(
        output_channels=output_channels,
        kernel_shape=kernel_shape,
        stride=stride,
        padding='VALID',
        **conv_kwargs)

  def __call__(self, inputs):
    """Applies PeriodicTransposeConvolution to `inputs`.

    Args:
      inputs: array with spatial and channel axes to which
        PeriodicTransposeConvolution is applied. The first
        `len(kernel_shape)` axes are treated as the spatial ones.

    Returns:
      `inputs` convolved with the kernel of the module with periodic padding.
    """
    ndim = len(self._kernel_shape)
    stride = self._stride
    # Along each spatial axis keep the stride * axis_size outputs that start
    # right after the ones produced by the left padding.
    spatial_window = tuple(
        slice(stride * left_pad, stride * (inputs.shape[axis] + left_pad))
        for axis, (left_pad, _) in enumerate(self._padding[:ndim]))
    # The trailing (channel) axis is kept whole.
    window = spatial_window + (slice(None, None),)
    convolved = tiling.apply_convolution(
        self._conv_module, inputs, self._tile_layout, self._padding)
    return jnp.roll(convolved[window], self._roll_shifts, list(range(ndim)))


class PeriodicConvTranspose1D(PeriodicConvTransposeGeneral):
  """`PeriodicConvTransposeGeneral` specialized to `hk.Conv1DTranspose`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int],
      stride: int = 1,
      tile_layout: Optional[Tuple[int]] = None,
      name='periodic_conv_transpose_1d',
      **conv_kwargs
  ):
    """Constructs PeriodicConvTranspose1D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 1D kernel.
      stride: stride of the transpose convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv1DTranspose`.
    """
    # Only the base convolution class is fixed here; the rest is forwarded.
    super().__init__(
        hk.Conv1DTranspose, output_channels, kernel_shape,
        stride=stride, tile_layout=tile_layout, name=name, **conv_kwargs)


class PeriodicConvTranspose2D(PeriodicConvTransposeGeneral):
  """`PeriodicConvTransposeGeneral` specialized to `hk.Conv2DTranspose`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int, int],
      stride: int = 1,
      tile_layout: Optional[Tuple[int, int]] = None,
      name='periodic_conv_transpose_2d',
      **conv_kwargs
  ):
    """Constructs PeriodicConvTranspose2D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 2D kernel.
      stride: stride of the transpose convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv2DTranspose`.
    """
    # Only the base convolution class is fixed here; the rest is forwarded.
    super().__init__(
        hk.Conv2DTranspose, output_channels, kernel_shape,
        stride=stride, tile_layout=tile_layout, name=name, **conv_kwargs)


class PeriodicConvTranspose3D(PeriodicConvTransposeGeneral):
  """`PeriodicConvTransposeGeneral` specialized to `hk.Conv3DTranspose`."""

  def __init__(
      self,
      output_channels: int,
      kernel_shape: Tuple[int, int, int],
      stride: int = 1,
      tile_layout: Optional[Tuple[int, int, int]] = None,
      name='periodic_conv_transpose_3d',
      **conv_kwargs
  ):
    """Constructs PeriodicConvTranspose3D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 3D kernel.
      stride: stride of the transpose convolution.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      name: name of the module.
      **conv_kwargs: additional arguments forwarded to `hk.Conv3DTranspose`.
    """
    # Only the base convolution class is fixed here; the rest is forwarded.
    super().__init__(
        hk.Conv3DTranspose, output_channels, kernel_shape,
        stride=stride, tile_layout=tile_layout, name=name, **conv_kwargs)


def rescale_to_range(
    inputs: Array,
    min_value: float,
    max_value: float,
    axes: Tuple[int, ...]
) -> Array:
  """Rescales inputs to [min_value, max_value] range.

  Note that this function performs input dependent transformation, which might
  not be suitable for models that aim to learn different dynamics for different
  scales.

  Slices of `inputs` that are constant across `axes` have no extent to
  stretch; they are mapped to `min_value` instead of producing NaNs from a
  0 / 0 division.

  Args:
    inputs: array to be rescaled to [min_value, max_value] range.
    min_value: value to which the smallest entry of `inputs` is mapped to.
    max_value: value to which the largest entry of `inputs` is mapped to.
      Expected to differ from `min_value`.
    axes: `inputs` axes across which we search for smallest and largest values.

  Returns:
    `inputs` rescaled to [min_value, max_value] range.
  """
  inputs_max = jnp.max(inputs, axis=axes, keepdims=True)
  inputs_min = jnp.min(inputs, axis=axes, keepdims=True)
  scale = (inputs_max - inputs_min) / (max_value - min_value)
  # Where max == min the numerator below is exactly zero, so any non-zero
  # denominator yields `min_value`; substituting 1 avoids the 0 / 0 NaN.
  safe_scale = jnp.where(scale == 0, jnp.ones_like(scale), scale)
  return (inputs - inputs_min) / safe_scale + min_value


class NonPeriodicConvGeneral(hk.Module):
  """General non-periodic convolution module.

  Wraps a standard convolution module configured with `VALID` padding and no
  explicit padding of the inputs. On each call a leading axis is added to the
  inputs before the convolution and removed again from its output.
  """

  def __init__(self,
               base_convolution: Callable[..., Any],
               output_channels: int,
               kernel_shape: Tuple[int, ...],
               rate: int = 1,
               name: str = 'periodic_conv_general',
               **conv_kwargs: Any):
    """Constructs NonPeriodicConvGeneral module similar to a Periodic one.

    Args:
      base_convolution: standard convolution module e.g. hk.Conv1D.
      output_channels: number of output channels.
      kernel_shape: shape of the kernel, compatible with `base_convolution`.
      rate: dilation rate of the convolution.
      name: name of the module. The default matches the one of
        `PeriodicConvGeneral`.
      **conv_kwargs: additional arguments passed to `base_convolution`.
    """
    super().__init__(name=name)
    conv = base_convolution(
        output_channels=output_channels,
        kernel_shape=kernel_shape,
        rate=rate,
        padding='VALID',
        **conv_kwargs)
    self._conv_module = conv

  def __call__(self, inputs):
    """Convolves `inputs` through a temporary leading axis of size one."""
    with_leading_axis = jnp.expand_dims(inputs, axis=0)
    convolved = self._conv_module(with_leading_axis)
    return jnp.squeeze(convolved, axis=0)


class NonPeriodicConv1D(NonPeriodicConvGeneral):
  """`NonPeriodicConvGeneral` specialized to `hk.Conv1D`."""

  def __init__(self,
               output_channels: int,
               kernel_shape: Tuple[int, ...],
               rate: int = 1,
               name='periodic_conv_1d',
               **conv_kwargs):
    """Constructs NonPeriodicConv1D module.

    Args:
      output_channels: number of output channels.
      kernel_shape: shape of the 1D kernel.
      rate: dilation rate of the convolution.
      name: name of the module. The default matches the one of
        `PeriodicConv1D`.
      **conv_kwargs: additional arguments forwarded to `hk.Conv1D`.
    """
    # Only the base convolution class is fixed here; the rest is forwarded.
    super().__init__(
        hk.Conv1D, output_channels, kernel_shape,
        rate=rate, name=name, **conv_kwargs)


class PolynomialConstraint():
  """Module that parametrizes coefficients of polynomially accurate derivatives.

  Produces stencil coefficients that approximate a derivative of order
  `derivative_orders[i]` along the i-th dimension with polynomial accuracy of
  order `accuracy_order`. Accuracy is enforced by construction: the output is
  a fixed bias solution of the linear accuracy constraints plus a linear
  superposition of vectors spanning the nullspace of those constraints. The
  bias is either given through the `bias` argument or generated with
  `layers_util.polynomial_accuracy_coefficients`.
  """

  def __init__(
      self,
      stencils: Sequence[np.ndarray],
      derivative_orders: Tuple[int, ...],
      method: layers_util.Method,
      steps: Tuple[float, ...],
      accuracy_order: int = 1,
      bias_accuracy_order: int = 1,
      bias: Optional[Array] = None,
      precision: lax.Precision = lax.Precision.HIGHEST
  ):
    """Constructs the object.

    Args:
      stencils: sequence of 1d stencils, one per grid dimension
      derivative_orders: derivative orders along corresponding directions.
      method: discretization method (finite volumes or finite differences).
      steps: spatial separations between the adjacent cells. All entries must
        be equal.
      accuracy_order: order to which polynomial accuracy is enforced.
      bias_accuracy_order: integer order of polynomial accuracy to use for the
        bias term. Only used if bias is not provided.
      bias: np.ndarray of shape (grid_size,) to which zero-vectors will be
        mapped. Must satisfy polynomial accuracy to the requested order. By
        default, we use standard low-order coefficients for the given grid.
      precision: numerical precision for matrix multiplication. Only relevant
        on TPUs.

    Raises:
      ValueError: if `steps` are not all equal, if `bias` does not satisfy the
        accuracy constraints, or if the constraints leave no free directions.
    """
    self.precision = precision

    distinct_steps = set(steps)
    if len(distinct_steps) != 1:
      raise ValueError('nonuniform steps not supported by PolynomialConstraint')
    grid_step = next(iter(distinct_steps))

    # Any coefficient vector `c` with `constraint_matrix @ c == rhs` is
    # polynomially accurate to the requested order.
    constraint_matrix, rhs = layers_util.polynomial_accuracy_constraints(
        stencils, method, derivative_orders, accuracy_order, grid_step)

    if bias is None:
      bias = layers_util.polynomial_accuracy_coefficients(
          stencils, method, derivative_orders, bias_accuracy_order,
          grid_step).ravel()
    self.bias = bias

    # The bias must itself solve the constraint system.
    residual = np.dot(constraint_matrix, bias) - rhs
    if np.linalg.norm(residual) > 1e-8:
      raise ValueError('invalid bias, not in nullspace')

    # https://en.wikipedia.org/wiki/Kernel_(linear_algebra)#Nonhomogeneous_systems_of_linear_equations
    _, _, v = np.linalg.svd(constraint_matrix)
    num_rows, num_cols = constraint_matrix.shape[:2]
    nullspace_size = num_cols - num_rows
    if nullspace_size == 0:
      raise ValueError(
          'there is only one valid solution accurate to this order')
    self._nullspace_size = nullspace_size

    # The nullspace basis from the SVD is normalized, i.e. independent of the
    # grid spacing, so it is rescaled by grid_step ** order for every
    # dimension. The division is done in place on the slice of `v`.
    rescaling = np.prod(grid_step**np.array(derivative_orders))
    self.nullspace = v[-nullspace_size:]
    self.nullspace /= rescaling

  @property
  def subspace_size(self) -> int:
    """Returns the size of the coefficients subspace with desired accuracy."""
    return self._nullspace_size

  def __call__(
      self,
      inputs: Array
  ) -> Array:
    """Returns polynomially accurate coefficients parametrized by `inputs`.

    Args:
      inputs: array whose last dimension represents linear superposition of
        valid polynomially accurate coefficients. Last dimension must be equal
        to `subspace_size`.

    Returns:
      array whose last dimension represents valid coefficients that approximate
      `derivative_orders` with polynomial accuracy on a stencil specified in
      `stencils` at position (0.,) * ndims.
    """
    # Contract the last axis of `inputs` with the nullspace basis vectors.
    superposition = jnp.tensordot(
        inputs, self.nullspace, axes=[-1, 0], precision=self.precision)
    return self.bias + superposition


class StencilCoefficients():
  """Module that approximates stencil coefficients with polynomial accuracy.

  Chains a trainable model built by `tower_factory` with a
  `PolynomialConstraint` layer: the model output parametrizes the constraint,
  which turns it into stencil coefficients approximating a spatial derivative
  of order `derivative_orders[i]` along the i'th dimension.
  """

  def __init__(
      self,
      stencils: Sequence[np.ndarray],
      derivative_orders: Tuple[int, ...],
      tower_factory: Callable[[int], Callable[..., Any]],
      steps: Tuple[float, ...],
      method: layers_util.Method = layers_util.Method.FINITE_VOLUME,
      **kwargs: Any,
  ):
    """Constructs the object.

    Args:
      stencils: sequence of 1d stencils, one per grid dimension
      derivative_orders: derivative orders along corresponding directions.
      tower_factory: callable that constructs a neural network with specified
        number of output channels and the same spatial resolution as its inputs.
      steps: spatial separations between the adjacent cells.
      method: discretization method passed to PolynomialConstraint.
      **kwargs: additional arguments to be passed to PolynomialConstraint
        constructor.
    """
    constraint = PolynomialConstraint(
        stencils, derivative_orders, method, steps, **kwargs)
    self._polynomial_constraint = constraint
    # The tower must emit one channel per free direction of the constraint.
    self._tower = tower_factory(constraint.subspace_size)

  def __call__(self, inputs: Array, **kwargs) -> Array:
    """Returns coefficients approximating derivative conditioned on `inputs`."""
    return self._polynomial_constraint(self._tower(inputs, **kwargs))


class SpatialDerivativeFromLogits:
  """Module that transforms logits to polynomially accurate derivatives.

  The logits are passed through a `PolynomialConstraint` layer and the
  resulting coefficients are combined with patches extracted from the inputs.
  Unlike `SpatialDerivative`, this module does not compute the logits itself;
  they are supplied by the caller.
  """

  def __init__(
      self,
      stencil_shape: Tuple[int, ...],
      input_offset: Tuple[float, ...],
      target_offset: Tuple[float, ...],
      derivative_orders: Tuple[int, ...],
      steps: Tuple[float, ...],
      extract_patch_method: str = 'roll',
      tile_layout: Optional[Tuple[int, ...]] = None,
      method: layers_util.Method = layers_util.Method.FINITE_VOLUME,
  ):
    """Constructs the object.

    Args:
      stencil_shape: shape of the stencil from which patches are extracted.
      input_offset: offset of the inputs, passed to
        `layers_util.get_roll_and_shift`.
      target_offset: offset of the output, passed to
        `layers_util.get_roll_and_shift`.
      derivative_orders: derivative orders along corresponding directions.
      steps: spatial separations between the adjacent cells.
      extract_patch_method: method forwarded to `layers_util.extract_patches`.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
      method: discretization method passed to PolynomialConstraint.
    """
    roll, shift = layers_util.get_roll_and_shift(input_offset, target_offset)
    self.roll = roll
    self.stencil_shape = stencil_shape
    self.tile_layout = tile_layout
    self._extract_patch_method = extract_patch_method
    self.constraint = PolynomialConstraint(
        layers_util.get_stencils(stencil_shape, shift, steps),
        derivative_orders, method, steps)

  @property
  def subspace_size(self) -> int:
    """Size of the last dimension expected from `logits`."""
    return self.constraint.subspace_size

  @property
  def stencil_size(self) -> int:
    """Total number of points in the stencil."""
    return int(np.prod(self.stencil_shape))

  def _validate_logits(self, logits):
    """Raises ValueError unless `logits.shape[-1]` equals `subspace_size`."""
    if logits.shape[-1] == self.subspace_size:
      return
    raise ValueError('The last dimension of `logits` did not match subspace '
                     f'size; {logits.shape[-1]} vs. {self.subspace_size}')

  def extract_patches(self, inputs):
    """Rolls `inputs` by `self.roll` and extracts stencil-shaped patches."""
    # NOTE(review): no `axis` is given to `jnp.roll`, so a non-zero roll acts
    # on the flattened array — confirm this is intended.
    return layers_util.extract_patches(
        jnp.roll(inputs, self.roll),
        self.stencil_shape,
        self._extract_patch_method,
        self.tile_layout)

  @functools.partial(jax.named_call, name='SpatialDerivativeFromLogits')
  def __call__(self, inputs, logits):
    """Evaluates the derivative of `inputs` parametrized by `logits`.

    Args:
      inputs: array from which stencil patches are extracted.
      logits: array whose last dimension equals `subspace_size`.

    Returns:
      Result of `layers_util.apply_coefficients` on the constrained
      coefficients and the extracted patches.

    Raises:
      ValueError: if the last dimension of `logits` has the wrong size.
    """
    self._validate_logits(logits)
    coefficients = self.constraint(logits)
    return layers_util.apply_coefficients(
        coefficients, self.extract_patches(inputs))


T = TypeVar('T')


def fuse_spatial_derivative_layers(
    derivatives: Dict[T, SpatialDerivativeFromLogits],
    all_logits: jnp.ndarray,
    *,
    constrain_with_conv: bool = False,
    fuse_patches: bool = False,
) -> Dict[T, Callable[[jnp.ndarray], jnp.ndarray]]:
  """Evaluate spatial derivatives by fusing together constraints.

  The biases of all derivatives are concatenated and their nullspaces are
  assembled into one block diagonal matrix, so that the coefficients of every
  derivative come out of a single contraction with `all_logits`. Despite the
  additional calculation, this can be faster on TPUs because the full block
  diagonal constraint matrix is small enough to fit within a 128x128 matrix.

  All derivatives must share the same `precision` and `tile_layout`, and none
  of them may use a non-zero roll.

  Args:
    derivatives: mapping from key to SpatialDerivativeFromLogits.
    all_logits: stacked logits to use as input into spatial derivatives.
    constrain_with_conv: whether to constrain with a 1x1 convolution instead of
      direct matrix multiplication. This path takes `len(tile_layout)`, so it
      needs a tile layout that is not None.
    fuse_patches: whether to also fuse the extraction of patches.

  Returns:
    Functions that when applied evaluate derivatives.

  Raises:
    ValueError: if any derivative uses a non-zero roll, or if the derivatives
      do not agree on a single precision or tile layout.
  """
  deriv_layers = list(derivatives.values())
  constraints = [layer.constraint for layer in deriv_layers]

  joint_bias = jnp.concatenate([c.bias for c in constraints])
  joint_nullspace = scipy.linalg.block_diag(*[c.nullspace for c in constraints])
  # Unpacking a one-element set fails unless all layers agree on the value.
  (precision,) = {c.precision for c in constraints}
  (tile_layout,) = {layer.tile_layout for layer in deriv_layers}

  if constrain_with_conv:
    ndim = len(tile_layout)
    # Leading singleton axes turn the matrix into a 1x1 convolution kernel.
    kernel = jnp.expand_dims(
        joint_nullspace.astype(np.float32), axis=tuple(range(ndim)))
    constrained = layers_util.periodic_convolution(
        all_logits, kernel, tile_layout=tile_layout, precision=precision)
    all_coefficients = joint_bias + constrained
  else:
    tiled = tile_layout is not None
    logits = (
        tiling.space_to_batch(all_logits, tile_layout) if tiled
        else all_logits)
    all_coefficients = joint_bias + jnp.tensordot(
        logits, joint_nullspace, axes=[-1, 0], precision=precision)
    if tiled:
      all_coefficients = tiling.batch_to_space(all_coefficients, tile_layout)

  stencil_sizes = [layer.stencil_size for layer in deriv_layers]
  split_points = np.cumsum(stencil_sizes)
  # The last split point equals the total size, so `jnp.split` also yields a
  # trailing empty piece; `zip` drops it.
  coefficients_map = dict(
      zip(derivatives, jnp.split(all_coefficients, split_points, axis=-1)))

  stencil_shapes = [layer.stencil_shape for layer in deriv_layers]
  for k, deriv in derivatives.items():
    if not all(r == 0 for r in deriv.roll):
      raise ValueError(f'derivative {k} uses roll: {deriv.roll}')

  @functools.partial(jax.named_call, name='evaluate_derivatives')
  def evaluate(key, inputs):
    if not fuse_patches:
      patches = derivatives[key].extract_patches(inputs)
      return layers_util.apply_coefficients(coefficients_map[key], patches)
    # Fused path: extract the patches of all derivatives at once, multiply
    # them with the coefficients and sum the slice that belongs to `key`.
    all_patches = layers_util.fused_extract_patches(
        inputs, stencil_shapes, tile_layout)
    split_terms = jnp.split(
        all_coefficients * all_patches, split_points, axis=-1)
    position = list(derivatives).index(key)
    return jnp.sum(split_terms[position], axis=-1, keepdims=True)

  return {k: functools.partial(evaluate, k) for k in derivatives}


class SpatialDerivative:
  """Module that learns spatial derivative with polynomial accuracy.

  Combines `StencilCoefficients` with patch extraction to construct a
  trainable model that approximates a spatial derivative.
  """

  def __init__(
      self,
      stencil_shape: Tuple[int, ...],
      input_offset: Tuple[float, ...],
      target_offset: Tuple[float, ...],
      derivative_orders: Tuple[int, ...],
      tower_factory: Callable[[int], Callable[..., Any]],
      steps: Tuple[float, ...],
      extract_patch_method: str = 'roll',
      tile_layout: Optional[Tuple[int, ...]] = None,
  ):
    """Constructs the object.

    Args:
      stencil_shape: shape of the stencil from which patches are extracted.
      input_offset: offset of the inputs, passed to
        `layers_util.get_roll_and_shift`.
      target_offset: offset of the output, passed to
        `layers_util.get_roll_and_shift`.
      derivative_orders: derivative orders along corresponding directions.
      tower_factory: callable that constructs the network producing the
        parametrization of the stencil coefficients.
      steps: spatial separations between the adjacent cells.
      extract_patch_method: method forwarded to `layers_util.extract_patches`.
      tile_layout: optional layout for tiling spatial dimensions in a batch.
    """
    self._stencil_shape = stencil_shape
    self._roll, self._shift = layers_util.get_roll_and_shift(
        input_offset, target_offset)
    stencils = layers_util.get_stencils(stencil_shape, self._shift, steps)
    self._coefficients_module = StencilCoefficients(
        stencils, derivative_orders, tower_factory, steps)
    self._extract_patch_method = extract_patch_method
    self._tile_layout = tile_layout

  @functools.partial(jax.named_call, name='SpatialDerivative')
  def __call__(self, inputs, *auxiliary_inputs):
    """Computes spatial derivative of `inputs` evaluated at `offset`.

    Args:
      inputs: array to differentiate; stencil patches are extracted from it.
      *auxiliary_inputs: optional extra arrays that are rolled like `inputs`
        and concatenated to it along the last axis before being fed to the
        coefficients network. They do not contribute to the patches.

    Returns:
      Result of `layers_util.apply_coefficients` on the predicted coefficients
      and the extracted patches.
    """
    # TODO(jamieas): consider moving this roll inside `extract_patches`. For the
    # `roll` implementation of `extract_patches`, we can simply add it to the
    # `shifts`. For the `conv` implementation, we may be able to effectively
    # roll the input by adjusting how arrays are padded.
    rolled = jnp.roll(inputs, self._roll)
    patches = layers_util.extract_patches(
        rolled, self._stencil_shape,
        self._extract_patch_method, self._tile_layout)
    # `auxiliary_inputs` is a varargs tuple and therefore never None; only
    # concatenate when at least one auxiliary array was actually passed.
    if auxiliary_inputs:
      rolled_aux = [jnp.roll(aux, self._roll) for aux in auxiliary_inputs]
      rolled = jnp.concatenate([rolled, *rolled_aux], axis=-1)
    coefficients = self._coefficients_module(rolled)
    return layers_util.apply_coefficients(coefficients, patches)