# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized mapping functions."""

import functools
import inspect

from typing import Any, Callable, Optional, Sequence, Union

import haiku as hk
import jax
import jax.numpy as jnp


PYTREE = Any
PYTREE_JAX_ARRAY = Any

partial = functools.partial
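# Sentinel marking axes that should be broadcast to every shard rather than
# sliced; it stands in for `None` in the expanded axes pytrees.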
PROXY = object()


def _maybe_slice(array, i, slice_size, axis):
  if axis is PROXY:
    return array
  else:
    return jax.lax.dynamic_slice_in_dim(
        array, i, slice_size=slice_size, axis=axis)


def _maybe_get_size(array, axis):
  if axis is PROXY:
    return -1
  else:
    return array.shape[axis]


def _expand_axes(axes, values, name='sharded_apply'):
  values_tree_def = jax.tree_util.tree_flatten(values)[1]
  flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
  # Replace Nones with PROXY.
  flat_axes = [PROXY if x is None else x for x in flat_axes]
  return jax.tree_util.tree_unflatten(values_tree_def, flat_axes)
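

# For illustration (a rough sketch of the intended behaviour; the dict inputs
# are hypothetical): _expand_axes(0, {'a': x, 'b': y}) yields {'a': 0, 'b': 0},
# while _expand_axes({'a': 0, 'b': None}, {'a': x, 'b': y}) replaces the `None`
# with PROXY so that leaf is broadcast rather than sliced.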


def sharded_map(
    fun: Callable[..., PYTREE_JAX_ARRAY],
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded vmap.

  Maps `fun` over axes, in a way similar to vmap, but does so in shards of
  `shard_size`. This allows a smooth trade-off between memory usage (as in a
  plain map) and higher throughput (as in a vmap).

  Args:
    fun: Function to apply the sharded map transform to.
    shard_size: Integer denoting the shard size.
    in_axes: Integer or pytree describing which axis to map over for each input
      to `fun`; `None` denotes broadcasting.
    out_axes: Integer or pytree denoting the axis in the output onto which the
      mapped-over axis is placed.

  Returns:
    The function with the sharded map applied.
  """
  if 'split_rng' in inspect.signature(hk.vmap).parameters:
    vmapped_fun = hk.vmap(fun, in_axes, out_axes, split_rng=False)
  else:
    # TODO(tomhennigan): Remove this when older versions of Haiku aren't used.
    vmapped_fun = hk.vmap(fun, in_axes, out_axes)
  return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
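

# Example usage (a minimal sketch; `my_fun`, `x` and the shapes below are
# hypothetical, not part of this module):
#
#   def my_fun(x):        # x: [channels]
#     return x * 2.0
#
#   mapped_fun = sharded_map(my_fun, shard_size=4)
#   y = mapped_fun(x)     # x: [batch, channels] -> y: [batch, channels]
#
# Since this relies on `hk.vmap` and `hk.scan`, the mapped function is
# typically called from inside a Haiku transform (e.g. a module's __call__).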


def sharded_apply(
    fun: Callable[..., PYTREE_JAX_ARRAY],  # pylint: disable=g-bare-generic
    shard_size: Union[int, None] = 1,
    in_axes: Union[int, PYTREE] = 0,
    out_axes: Union[int, PYTREE] = 0,
    new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
  """Sharded apply.

  Applies `fun` over the mapped axes in a way similar to vmap, but does so in
  shards of `shard_size`; the per-shard outputs are then stitched back
  together. This allows a smooth trade-off between memory usage (as in a plain
  map) and higher throughput (as in a vmap).

  Args:
    fun: Function to apply the sharded transform to.
    shard_size: Integer denoting the shard size.
    in_axes: Integer or pytree describing which axis to map over for each input
      to `fun`; `None` denotes broadcasting.
    out_axes: Integer or pytree denoting the axis in the output onto which the
      mapped-over axis is placed.
    new_out_axes: Whether to stack outputs on new axes. This assumes that the
      output sizes for each shard (including the possible remainder shard) are
      the same.

  Returns:
    The function with the sharded apply transform applied.
  """
  docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
            'but with additional array axes over which {fun} is mapped.')
  if new_out_axes:
    raise NotImplementedError('New output axes not yet implemented.')

  # A shard_size of None denotes no sharding.
  if shard_size is None:
    return fun

  @jax.util.wraps(fun, docstr=docstr)
  def mapped_fn(*args):
    # Expand in_axes and determine the loop range.
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)

    num_extra_shards = (in_size - 1) // shard_size

    # Fix up the last shard size: a zero remainder means the last shard is full.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size,
                                           axis),
          args, in_axes_)
      return fun(*input_slice)

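    # Abstractly evaluate the shapes/dtypes of a remainder-sized shard (no real
    # computation happens here) so that full output buffers can be preallocated.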
    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size))
    out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
    out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size))
      shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)

      def make_output_shape(axis, shard_shape, remainder_shape):
        return shard_shape[:axis] + (
            shard_shape[axis] * num_extra_shards +
            remainder_shape[axis],) + shard_shape[axis + 1:]

      out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
                                out_shapes)

    # Wraps dynamic_update_slice_in_dim with a different argument order; this
    # is needed because tree_map only works with positional arguments.
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(
          dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree_map(update_slice, outputs, slice_out, out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

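    # Starting offsets of the full-size shards processed by the scan below; a
    # smaller remainder shard (if any) is handled separately afterwards.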
    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs

  return mapped_fn
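

# Example usage (a minimal sketch; `channel_mean`, `x` and the shapes below are
# hypothetical). Unlike `sharded_map`, the wrapped function receives a whole
# shard at a time:
#
#   def channel_mean(x):            # x: [shard, channels]
#     return jnp.mean(x, axis=-1)   # -> [shard]
#
#   sharded_mean = sharded_apply(channel_mean, shard_size=128)
#   y = sharded_mean(x)             # x: [num_res, channels] -> y: [num_res]
#
# Each shard's output is written into a preallocated buffer along `out_axes`,
# so memory for intermediate activations scales with the shard size rather
# than with the full mapped dimension.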


def inference_subbatch(
    module: Callable[..., PYTREE_JAX_ARRAY],
    subbatch_size: int,
    batched_args: Sequence[PYTREE_JAX_ARRAY],
    nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
    low_memory: bool = True,
    input_subbatch_dim: int = 0,
    output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
  """Run through subbatches (like batch apply but with split and concat)."""
  assert len(batched_args) > 0  # pylint: disable=g-explicit-length-test

  if not low_memory:
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)

  if output_subbatch_dim is None:
    output_subbatch_dim = input_subbatch_dim

  def run_module(*batched_args):
    args = list(batched_args) + list(nonbatched_args)
    return module(*args)
  sharded_module = sharded_apply(run_module,
                                 shard_size=subbatch_size,
                                 in_axes=input_subbatch_dim,
                                 out_axes=output_subbatch_dim)
  return sharded_module(*batched_args)
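

# Example usage (a minimal sketch; `attention_module`, `q_data`, `m_data` and
# `bias` are hypothetical stand-ins for real inputs):
#
#   output = inference_subbatch(
#       attention_module,
#       subbatch_size=4,
#       batched_args=[q_data, m_data, bias],  # sliced along input_subbatch_dim
#       nonbatched_args=[],                   # passed unchanged to each subbatch
#       low_memory=True)
#
# With `low_memory=False` the module is simply applied to the full batch in a
# single call.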