"""Utilities for spatial tiling of periodic convolutions into batch dimensions.

``layout`` tuple indicates how the corresponding spatial dimensions are laid
out in space. In 2D:
- `(1, 1)` indicates no tiling.
- `(4, 2)` indicates 4 x-tiles and 2 y-tiles
- `(16, 8)` indicates 16 x-tiles and 8 y-tiles

Tiling is helpful for getting the highest performance convolutions on TPU. Per
the TPU performance guide [1], batch dimensions on TPUs are tiled to multiples
of 8 or 128. Thus the product of all elements in `layout` should typically be
either 8 or 128.

[1] https://cloud.google.com/tpu/docs/performance-guide.
"""
import functools
from typing import Callable, Optional, Sequence, Tuple

import einops
import jax
from jax import lax
import jax.numpy as jnp


Array = jnp.ndarray


def _prod(xs):
  # backport of math.prod() from Python 3.8+
  result = 1
  for x in xs:
    result *= x
  return result


def _verify_layout(array, layout):
  if array.ndim != len(layout) + 2 or array.shape[0] != _prod(layout):
    raise ValueError(
        f"array shape does not match layout: {array.shape} vs {layout}")


def _layout_to_dict(layout):
  return dict(zip(["bx", "by", "bz"], layout))


def _tile_roll(array, layout, shift, axis):
  """Roll along the "tiled" dimension."""
  _verify_layout(array, layout)
  sizes = _layout_to_dict(layout)
  if len(layout) == 1:
    array = jnp.roll(array, shift, axis=axis)
  elif len(layout) == 2:
    array = einops.rearrange(array, "(bx by) ... -> bx by ...", **sizes)
    array = jnp.roll(array, shift, axis=axis)
    array = einops.rearrange(array, "bx by ... -> (bx by) ...", **sizes)
  elif len(layout) == 3:
    array = einops.rearrange(array, "(bx by bz) ... -> bx by bz ...", **sizes)
    array = jnp.roll(array, shift, axis=axis)
    array = einops.rearrange(array, "bx by bz ... -> (bx by bz) ...", **sizes)
  else:
    raise NotImplementedError
  return array


def _halo_pad_1d(array, layout, axis, padding=(1, 1)):
  """Pad for halo-exchange along a single array dimension."""
  pad_left, pad_right = padding
  spatial_axis = axis + 1
  pieces = []

  if pad_left:
    # Note: importantly, dynamic_slice_in_dim raises an error for out of bounds
    # access, which catches the edge case where a single array is insufficient
    # padding.
    start = array.shape[spatial_axis] - pad_left
    input_right = lax.dynamic_slice_in_dim(array, start, pad_left, spatial_axis)
    output_left = _tile_roll(input_right, layout, shift=+1, axis=axis)
    pieces.append(output_left)

  pieces.append(array)

  if pad_right:
    start = 0
    input_left = lax.dynamic_slice_in_dim(array, start, pad_right, spatial_axis)
    output_right = _tile_roll(input_left, layout, shift=-1, axis=axis)
    pieces.append(output_right)

  return jnp.concatenate(pieces, axis=spatial_axis)


@functools.partial(jax.jit, static_argnums=(1, 2,))
def _halo_exchange_pad(array: Array, layout: Tuple[int, ...],
                       padding: Tuple[Tuple[int, int]]) -> Array:
  """Pad with halo-exchange in N-dimensions."""
  _verify_layout(array, layout)
  if len(layout) != len(padding):
    raise ValueError(f"invalid padding: {padding}")
  out = array
  for axis, pad in enumerate(padding):
    out = _halo_pad_1d(out, layout, axis, pad)
  return out


def halo_exchange_pad(
    array: Array,
    layout: Tuple[int, ...],
    padding: Sequence[Tuple[int, int]],
) -> Array:
  """Pad with halo-exchange in N-dimensions."""
  return _halo_exchange_pad(
      array, layout,
      tuple(map(tuple, padding)))


@functools.partial(jax.jit, static_argnums=(1,))
def space_to_batch(array: Array, layout: Tuple[int, ...]) -> Array:
  """Rearrange from space to batch dimensions."""
  sizes = _layout_to_dict(layout)
  if len(layout) == 1:
    path = "(bx x) c -> (bx) x c"
  elif len(layout) == 2:
    path = "(bx x) (by y) c -> (bx by) x y c"
  elif len(layout) == 3:
    path = "(bx x) (by y) (bz z) c -> (bx by bz) x y z c"
  else:
    raise NotImplementedError
  return einops.rearrange(array, path, **sizes)


@functools.partial(jax.jit, static_argnums=(1,))
def batch_to_space(array: Array, layout: Tuple[int, ...]) -> Array:
  """Rearrange from batch to space dimensions."""
  sizes = _layout_to_dict(layout)
  if len(layout) == 1:
    path = "(bx) x c -> (bx x) c"
  elif len(layout) == 2:
    path = "(bx by) x y c -> (bx x) (by y) c"
  elif len(layout) == 3:
    path = "(bx by bz) x y z c-> (bx x) (by y) (bz z) c"
  else:
    raise NotImplementedError
  return einops.rearrange(array, path, **sizes)


def apply_convolution(
    conv: Callable[[Array], Array],
    inputs: Array,
    layout: Tuple[int, ...],
    padding: Sequence[Tuple[int, ...]],
) -> Array:
  """Apply a valid convolution with tiling and periodic boundary conditions.

  Args:
    conv: function that calculates a convolution with valid boundary conditions
      when applied to an array of shape [batch, [spatial dims], channel].
    inputs: array of shape [[spatial dims], channel].
    layout: tiling layout for implementing the operation.
    padding: amount of periodic padding to add before and after each spatial
      dimension.

  Returns:
    Convolved array.
  """
  if layout is None:
    # TODO(shoyer): replace this with some sensible heuristic
    layout = (1,) * len(padding)
  tiled = space_to_batch(inputs, layout)
  padded = halo_exchange_pad(tiled, layout, padding)
  convolved = conv(padded)
  output = batch_to_space(convolved, layout)
  return output