array_utils.py 16.8 KB
Newer Older
mashun1's avatar
jax-cfd  
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility methods for manipulating array-like objects."""

from typing import Any, Callable, List, Tuple, Union

import jax
import jax.numpy as jnp
from jax_cfd.base import boundaries
from jax_cfd.base import grids
import numpy as np
import scipy.linalg

# There is currently no good way to indicate a jax "pytree" with arrays at its
# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more
# information about PyTrees and https://github.com/google/jax/issues/3340 for
# discussion of this issue.
PyTree = Any
Array = Union[np.ndarray, jax.Array]


def _normalize_axis(axis: int, ndim: int) -> int:
  """Validates and returns positive `axis` value."""
  if not -ndim <= axis < ndim:
    raise ValueError(f'invalid axis {axis} for ndim {ndim}')
  if axis < 0:
    axis += ndim
  return axis


def slice_along_axis(
    inputs: PyTree,
    axis: int,
    idx: Union[slice, int],
    expect_same_dims: bool = True
) -> PyTree:
  """Returns slice of `inputs` defined by `idx` along axis `axis`.

  Args:
    inputs: array or a tuple of arrays to slice.
    axis: axis along which to slice the `inputs`.
    idx: index or slice along axis `axis` that is returned.
    expect_same_dims: whether all arrays should have same number of dimensions.

  Returns:
    Slice of `inputs` defined by `idx` along axis `axis`.
  """
  # arrays, tree_def = jax.tree_util.flatten(inputs)
  arrays, tree_def = jax.tree_util.tree_flatten(inputs)
  ndims = set(a.ndim for a in arrays)
  if expect_same_dims and len(ndims) != 1:
    raise ValueError('arrays in `inputs` expected to have same ndims, but have '
                     f'{ndims}. To allow this, pass expect_same_dims=False')
  sliced = []
  for array in arrays:
    ndim = array.ndim
    slc = tuple(idx if j == _normalize_axis(axis, ndim) else slice(None)
                for j in range(ndim))
    sliced.append(array[slc])
  return jax.tree_util.tree_unflatten(tree_def, sliced)


def split_along_axis(
    inputs: PyTree,
    split_idx: int,
    axis: int,
    expect_same_dims: bool = True
) -> Tuple[PyTree, PyTree]:
  """Returns a tuple of slices of `inputs` split along `axis` at `split_idx`.

  Args:
    inputs: pytree of arrays to split.
    split_idx: index along `axis` where the second split starts.
    axis: axis along which to split the `inputs`.
    expect_same_dims: whether all arrays should have same number of dimensions.

  Returns:
    Tuple of slices of `inputs` split along `axis` at `split_idx`.
  """

  first_slice = slice_along_axis(
      inputs, axis, slice(0, split_idx), expect_same_dims)
  second_slice = slice_along_axis(
      inputs, axis, slice(split_idx, None), expect_same_dims)
  return first_slice, second_slice


def split_axis(
    inputs: PyTree,
    axis: int,
    keep_dims: bool = False
) -> Tuple[PyTree, ...]:
  """Splits the arrays in `inputs` along `axis`.

  Args:
    inputs: pytree to be split.
    axis: axis along which to split the `inputs`.
    keep_dims: whether to keep `axis` dimension.

  Returns:
    Tuple of pytrees that correspond to slices of `inputs` along `axis`. The
    `axis` dimension is removed if `squeeze is set to True.

  Raises:
    ValueError: if arrays in `inputs` don't have unique size along `axis`.
  """
  arrays, tree_def = jax.tree_util.flatten(inputs)
  axis_shapes = set(a.shape[axis] for a in arrays)
  if len(axis_shapes) != 1:
    raise ValueError(f'Arrays must have equal sized axis but got {axis_shapes}')
  axis_shape, = axis_shapes
  splits = [jnp.split(a, axis_shape, axis=axis) for a in arrays]
  if not keep_dims:
    splits = jax.tree_util.tree_map(lambda a: jnp.squeeze(a, axis), splits)
  splits = zip(*splits)
  return tuple(jax.tree_util.unflatten(tree_def, leaves) for leaves in splits)


def concat_along_axis(pytrees, axis):
  """Concatenates `pytrees` along `axis`."""
  concat_leaves_fn = lambda *args: jnp.concatenate(args, axis)
  return jax.tree_util.tree_map(concat_leaves_fn, *pytrees)


def block_reduce(
    array: Array,
    block_size: Tuple[int, ...],
    reduction_fn: Callable[[Array], Array]
) -> Array:
  """Breaks `array` into `block_size` pieces and applies `f` to each.

  This function is equivalent to `scikit-image.measure.block_reduce`:
  https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.block_reduce

  Args:
    array: an array.
    block_size: the size of the blocks on which the reduction is performed.
      Must evenly divide `array.shape`.
    reduction_fn: a reduction function that will be applied to each block of
      size `block_size`.
  Returns:
    The result of applying `f` to each block of size `block_size`.
  """
  new_shape = []
  for b, s in zip(block_size, array.shape):
    multiple, residual = divmod(s, b)
    if residual != 0:
      raise ValueError('`block_size` must divide `array.shape`;'
                       f'got {block_size}, {array.shape}.')
    new_shape += [multiple, b]
  multiple_axis_reduction_fn = reduction_fn
  for j in reversed(range(array.ndim)):
    multiple_axis_reduction_fn = jax.vmap(multiple_axis_reduction_fn, j)
  return multiple_axis_reduction_fn(array.reshape(new_shape))


def laplacian_matrix(size: int, step: float) -> np.ndarray:
  """Create 1D Laplacian operator matrix, with periodic BC."""
  column = np.zeros(size)
  column[0] = -2 / step**2
  column[1] = column[-1] = 1 / step**2
  return scipy.linalg.circulant(column)


def _laplacian_boundary_dirichlet_cell_centered(laplacians: List[Array],
                                                grid: grids.Grid, axis: int,
                                                side: str) -> None:
  """Converts 1d laplacian matrix to satisfy dirichlet homogeneous bc.

  laplacians[i] contains a 3 point stencil matrix L that approximates
  d^2/dx_i^2.
  For detailed documentation on laplacians input type see
  array_utils.laplacian_matrix.
  The default return of array_utils.laplacian_matrix makes a matrix for
  periodic boundary. For dirichlet boundary, the correct equation is
  L(u_interior) = rhs_interior and BL_boundary = u_fixed_boundary. So
  laplacian_boundary_dirichlet restricts the matrix L to
  interior points only.

  This function assumes RHS has cell-centered offset.
  Args:
    laplacians: list of 1d laplacians
    grid: grid object
    axis: axis along which to impose dirichlet bc.
    side: lower or upper side to assign boundary to.

  Returns:
    updated list of 1d laplacians.
  """
  # This function assumes homogeneous boundary, in which case if the offset
  # is 0.5 away from the wall, the ghost cell value u[0] = -u[1]. So the
  # 3 point stencil [1 -2 1] * [u[0] u[1] u[2]] = -3 u[1] + u[2].
  if side == 'lower':
    laplacians[axis][0, 0] = laplacians[axis][0, 0] - 1 / grid.step[axis]**2
  else:
    laplacians[axis][-1, -1] = laplacians[axis][-1, -1] - 1 / grid.step[axis]**2
  # deletes corner dependencies on the "looped-around" part.
  # this should be done irrespective of which side, since one boundary cannot
  # be periodic while the other is.
  laplacians[axis][0, -1] = 0.0
  laplacians[axis][-1, 0] = 0.0
  return


def _laplacian_boundary_neumann_cell_centered(laplacians: List[Array],
                                              grid: grids.Grid, axis: int,
                                              side: str) -> None:
  """Converts 1d laplacian matrix to satisfy neumann homogeneous bc.

  This function assumes the RHS will have a cell-centered offset.
  Neumann boundaries are not defined for edge-aligned offsets elsewhere in the
  code.

  Args:
    laplacians: list of 1d laplacians
    grid: grid object
    axis: axis along which to impose dirichlet bc.
    side: which boundary side to convert to neumann homogeneous bc.

  Returns:
    updated list of 1d laplacians.
  """
  if side == 'lower':
    laplacians[axis][0, 0] = laplacians[axis][0, 0] + 1 / grid.step[axis]**2
  else:
    laplacians[axis][-1, -1] = laplacians[axis][-1, -1] + 1 / grid.step[axis]**2
  # deletes corner dependencies on the "looped-around" part.
  # this should be done irrespective of which side, since one boundary cannot
  # be periodic while the other is.
  laplacians[axis][0, -1] = 0.0
  laplacians[axis][-1, 0] = 0.0
  return


def laplacian_matrix_w_boundaries(
    grid: grids.Grid,
    offset: Tuple[float, ...],
    bc: boundaries.BoundaryConditions,
) -> List[Array]:
  """Returns 1d laplacians that satisfy boundary conditions bc on grid.

  Given grid, offset and boundary conditions, returns a list of 1 laplacians
  (one along each axis).

  Currently, only homogeneous or periodic boundary conditions are supported.

  Args:
    grid: The grid used to construct the laplacian.
    offset: The offset of the variable on which laplacian acts.
    bc: the boundary condition of the variable on which the laplacian acts.

  Returns:
    A list of 1d laplacians.
  """
  if not isinstance(bc, boundaries.ConstantBoundaryConditions):
    raise NotImplementedError(
        f'Explicit laplacians are not implemented for {bc}.')
  laplacians = list(map(laplacian_matrix, grid.shape, grid.step))
  for axis in range(grid.ndim):
    if np.isclose(offset[axis], 0.5):
      for i, side in enumerate(['lower', 'upper']):  # lower and upper boundary
        if bc.types[axis][i] == boundaries.BCType.NEUMANN:
          _laplacian_boundary_neumann_cell_centered(
              laplacians, grid, axis, side)
        elif bc.types[axis][i] == boundaries.BCType.DIRICHLET:
          _laplacian_boundary_dirichlet_cell_centered(
              laplacians, grid, axis, side)
    if np.isclose(offset[axis] % 1, 0.):
      if bc.types[axis][0] == boundaries.BCType.DIRICHLET and bc.types[
          axis][1] == boundaries.BCType.DIRICHLET:
        # This function assumes homogeneous boundary and acts on the interior.
        # Thus, the laplacian can be cut off past the edge.
        # The interior grid has one fewer grid cell than the actual grid, so
        # the size of the laplacian should be reduced.
        laplacians[axis] = laplacians[axis][:-1, :-1]
      elif boundaries.BCType.NEUMANN in bc.types[axis]:
        raise NotImplementedError(
            'edge-aligned Neumann boundaries are not implemented.')
  return laplacians


def unstack(array, axis):
  """Returns a tuple of slices of `array` along axis `axis`."""
  squeeze_fn = lambda x: jnp.squeeze(x, axis=axis)
  return tuple(squeeze_fn(x) for x in jnp.split(array, array.shape[axis], axis))


def gram_schmidt_qr(
    matrix: Array,
    precision: jax.lax.Precision = jax.lax.Precision.HIGHEST
) -> Tuple[Array, Array]:
  """Computes QR decomposition using gramm-schmidt orthogonalization.

  This algorithm is suitable for tall matrices with very few columns. This
  method is more memory efficient compared to `jnp.linalg.qr`, but is less
  numerically stable, especially for matrices with many columns.

  Args:
    matrix: 2D array representing the matrix to be decomposed into orthogonal
      and upper triangular.
    precision: numerical precision for matrix multplication. Only relevant on
      TPUs.

  Returns:
    tuple of matrix Q whose columns are orthonormal and R that is upper
    triangular.
  """

  def orthogonalize(vector, others):
    """Returns the orthogonal component of `vector` with respect to `others`."""
    if not others:
      return vector / jnp.linalg.norm(vector)
    orthogonalize_step = lambda c, x: tuple([c - jnp.dot(c, x) * x, None])
    vector, _ = jax.lax.scan(orthogonalize_step, vector, jnp.stack(others))
    return vector / jnp.linalg.norm(vector)

  num_columns = matrix.shape[1]
  columns = unstack(matrix, axis=1)
  q_columns = []
  r_rows = []
  for vec_index, column in enumerate(columns):
    next_q_column = orthogonalize(column, q_columns)
    r_rows.append(jnp.asarray([
        jnp.dot(columns[i], next_q_column) if i >= vec_index else 0.
        for i in range(num_columns)]))
    q_columns.append(next_q_column)
  q = jnp.stack(q_columns, axis=1)
  r = jnp.stack(r_rows)
  # permute q columns to make entries of r on the diagonal positive.
  d = jnp.diag(jnp.sign(jnp.diagonal(r)))
  q = jnp.matmul(q, d, precision=precision)
  r = jnp.matmul(d, r, precision=precision)
  return q, r


def interp1d(  # pytype: disable=annotation-type-mismatch  # jnp-type
    x: Array,
    y: Array,
    axis: int = -1,
    fill_value: Union[str, Array] = jnp.nan,
    assume_sorted: bool = True,
) -> Callable[[Array], jax.Array]:
  """Build an interpolation function to approximate `y = f(x)`.

  x and y are arrays of values used to approximate some function f: y = f(x).
  This returns a function that uses linear interpolation to approximate f
  evaluated at new points.

  ```
  x = jnp.linspace(0, 10)
  y = jnp.sin(x)
  f = interp1d(x, y)

  x_new = 1.23
  f(x_new)
  ==> Approximately sin(1.23).

  x_new = ...  # Shape (4, 5) array
  f(x_new)
  ==> Shape (4, 5) array, approximating sin(x_new).
  ```

  Args:
    x: Length N 1-D array of values.
    y: Shape (..., N, ...) array of values corresponding to f(x).
    axis: Specifies the axis of y along which to interpolate. Interpolation
      defaults to the last axis of y.
    fill_value: Scalar array or string. If array, this value will be used to
      fill in for requested points outside of the data range. If not provided,
      then the default is NaN. If "extrapolate", then linear extrapolation is
      used.  If "constant_extrapolate", then the function is extended as a
      constant.
    assume_sorted: Whether to assume x is sorted. If True, x must be
      monotonically increasing. If False, this function sorts x and reorders
      y appropriately.

  Returns:
    Callable mapping array x_new to values y_new, where
      y_new.shape = y.shape[:axis] + x_new.shape + y.shape[axis + 1:]
  """
  allowed_fill_value_strs = {'constant_extrapolate', 'extrapolate'}
  if isinstance(fill_value, str):
    if fill_value not in allowed_fill_value_strs:
      raise ValueError(
          f'`fill_value` "{fill_value}" not in {allowed_fill_value_strs}')
  else:
    fill_value = np.asarray(fill_value)
    if fill_value.ndim > 0:
      raise ValueError(f'Only scalar `fill_value` supported. '
                       f'Found shape: {fill_value.shape}')

  x = jnp.asarray(x)
  if x.ndim != 1:
    raise ValueError(f'Expected `x` to be 1D. Found shape {x.shape}')
  if x.size < 2:
    raise ValueError(f'`x` must have at least 2 entries. Found shape {x.shape}')
  n_x = x.shape[0]
  if not assume_sorted:
    ind = jnp.argsort(x)
    x = x[ind]
    y = jnp.take(y, ind, axis=axis)

  y = jnp.asarray(y)
  if y.ndim < 1:
    raise ValueError(f'Expected `y` to have rank >= 1. Found shape {y.shape}')

  if x.size != y.shape[axis]:
    raise ValueError(
        f'x and y arrays must be equal in length along interpolation axis. '
        f'Found x.shape={x.shape} and y.shape={y.shape} and axis={axis}')

  axis = _normalize_axis(axis, ndim=y.ndim)

  def interp_func(x_new: jax.Array) -> jax.Array:
    """Implementation of the interpolation function."""
    x_new = jnp.asarray(x_new)

    # We will flatten x_new, then interpolate, then reshape the output.
    x_new_shape = x_new.shape
    x_new = jnp.ravel(x_new)

    # This construction of indices ensures that below_idx < above_idx, even at
    # x_new = x[i] exactly for some i.
    x_new_clipped = jnp.clip(x_new, np.min(x), np.max(x))
    above_idx = jnp.minimum(n_x - 1,
                            jnp.searchsorted(x, x_new_clipped, side='right'))
    below_idx = jnp.maximum(0, above_idx - 1)

    def expand(array):
      """Add singletons to rightmost dims of `array` so it bcasts with y."""
      array = jnp.asarray(array)
      return jnp.reshape(array, array.shape + (1,) * (y.ndim - axis - 1))

    x_above = jnp.take(x, above_idx)
    x_below = jnp.take(x, below_idx)
    y_above = jnp.take(y, above_idx, axis=axis)
    y_below = jnp.take(y, below_idx, axis=axis)
    slope = (y_above - y_below) / expand(x_above - x_below)

    if isinstance(fill_value, str) and fill_value == 'extrapolate':
      delta_x = expand(x_new - x_below)
      y_new = y_below + delta_x * slope
    elif isinstance(fill_value, str) and fill_value == 'constant_extrapolate':
      delta_x = expand(x_new_clipped - x_below)
      y_new = y_below + delta_x * slope
    else:  # Else fill_value is an Array.
      delta_x = expand(x_new - x_below)
      fill_value_ = expand(fill_value)
      y_new = y_below + delta_x * slope
      y_new = jnp.where(
          (delta_x < 0) | (delta_x > expand(x_above - x_below)),
          fill_value_, y_new)
    return jnp.reshape(
        y_new, y_new.shape[:axis] + x_new_shape + y_new.shape[axis + 1:])

  return interp_func