utils.py 45 KB
Newer Older
1
2
3
4
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
6
7
import functools
import operator
from typing import Any, Callable, Tuple, Sequence, Union, Iterable, Optional
8

9
import jax
10
11
import jax.numpy as jnp
import numpy as np
12
13
14
15
16
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
17
18
19
20
21
22
23
24
25
26

# Lightweight type aliases used in the annotations throughout this module.
PRNGKey = Any    # JAX PRNG key; intentionally left untyped.
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any    # any array-like value (jax/numpy array); intentionally left untyped.
# Accepted forms of an XLA precision setting (single value or a pair).
PrecisionLike = Union[
    None,
    str,
    lax.Precision,
    Tuple[str, str],
    Tuple[lax.Precision, lax.Precision],
]
# Signature shared by all parameter initializers: (key, shape, dtype) -> array.
Initializer = Callable[[PRNGKey, Shape, DType], Array]


27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def is_devices_enough(required):
    """Return True when ``jax.devices()`` reports at least ``required`` devices."""
    available = jax.devices()
    return required <= len(available)


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == 'linear':
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
    """Combine attention masks with a logical AND.

    Args:
        *masks: attention masks to combine; ``None`` entries are ignored. All
            remaining masks must have the same rank.
        dtype: dtype of the returned mask.

    Returns:
        The element-wise AND of the given masks cast to ``dtype``, or None when
        no (non-None) masks are given.
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None
    ranks = tuple(m.ndim for m in present)
    assert all(r == ranks[0] for r in ranks), f'masks must have same rank: {ranks}'
    combined = functools.reduce(jnp.logical_and, present)
    return combined.astype(dtype)


def combine_biases(*masks: Optional[Array]):
    """Combine additive attention biases by summation.

    Args:
        *masks: attention bias arrays to combine; ``None`` entries are ignored.
            All remaining biases must have the same rank.

    Returns:
        The sum of the given biases (the single bias itself when only one is
        given), or None when no (non-None) biases are given.
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None
    ranks = tuple(m.ndim for m in present)
    assert all(r == ranks[0] for r in ranks), f'masks must have same rank: {ranks}'
    return functools.reduce(operator.add, present)


def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          transpose_batch_sequence: bool,
                          bias: Optional[Array] = None,
                          dropout_rng: Optional[PRNGKey] = None,
                          dropout_rate: float = 0.,
                          deterministic: bool = False,
                          dtype: DType = jnp.float32,
                          float32_logits: bool = False):
    """Compute softmax(query . key^T + bias) . value for multi-head inputs.

    Core attention from https://arxiv.org/abs/1706.03762. No 1/sqrt(depth)
    scaling is applied inside this function.

    Args:
      query: `[batch, q_length, num_heads, qk_depth_per_head]`, or
        `[q_length, batch, num_heads, qk_depth_per_head]` when
        `transpose_batch_sequence` is True.
      key: `[batch, kv_length, num_heads, qk_depth_per_head]` (length-major
        when transposed).
      value: `[batch, kv_length, num_heads, v_depth_per_head]` (length-major
        when transposed).
      transpose_batch_sequence: whether the two leading axes are
        (length, batch) instead of (batch, length).
      bias: optional additive bias broadcastable to
        `[batch, num_heads, q_length, kv_length]` (masks, proximity bias, ...).
      dropout_rng: PRNG key for attention dropout; only used when dropout runs.
      dropout_rate: probability of dropping each attention weight.
      deterministic: if True, dropout is skipped.
      dtype: dtype of the attention weights after the softmax.
      float32_logits: if True, query and key are cast to float32 before the
        logits are formed, so the softmax also runs in float32.

    Returns:
      `[batch, q_length, num_heads, v_depth_per_head]`, or
      `[q_length, batch, num_heads, v_depth_per_head]` when transposed.
    """
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'

    # Resolve axis positions and einsum layouts once for the chosen layout.
    if transpose_batch_sequence:
        batch_dim, sequence_dim = 1, 0
        logits_spec, combine_spec = 'qbhd,kbhd->bhqk', 'bhqk,kbhd->qbhd'
    else:
        batch_dim, sequence_dim = 0, 1
        logits_spec, combine_spec = 'bqhd,bkhd->bhqk', 'bhqk,bkhd->bqhd'

    assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
        'q, k, v batch dims must match.')
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], 'q, k, v num_heads must match.'
    assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
    assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

    if float32_logits:
        # Upcast so logits and softmax are computed in float32 for stability.
        query, key = query.astype(jnp.float32), key.astype(jnp.float32)

    # logits: [batch, num_heads, q_length, kv_length]
    logits = jnp.einsum(logits_spec, query, key)
    if bias is not None:
        logits = logits + bias.astype(logits.dtype)

    # Normalize over the kv_length axis (the last one).
    weights = jax_nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax_random.bernoulli(dropout_rng, keep_prob, list(weights.shape))
        # Inverted dropout: surviving weights are rescaled by 1 / keep_prob.
        weights = weights * (keep.astype(weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))

    return jnp.einsum(combine_spec, weights, value)


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes.

    Attributes:
        features: number(s) of output features; an int or an iterable of ints.
        axis: input axis or axes to contract over (default: -1).
        dtype: the dtype of the computation (default: float32).
        kernel_init: initializer function for the weight matrix; when None, a
            fan-in truncated-normal variance-scaling initializer is used.
        kernel_axes: logical partitioning axis names for the kernel parameter.
        use_bias: whether to add a bias to the output (default: False).
        bias_init: initializer function for the bias vector.
        bias_axes: logical partitioning axis names for the bias parameter.
    """
    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
            inputs: The nd-array to be transformed; it is cast to `self.dtype`.

        Returns:
            The transformed input: the non-contracted input axes followed by
            the feature axes.
        """
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        # The kernel parameter is stored flattened as a 2-D
        # (prod(contracted dims), prod(features)) float32 matrix and reshaped to
        # (*contracted dims, *features) for the contraction.
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
        kernel = nn_partitioning.param_with_axes('kernel',
                                                 self.kernel_init,
                                                 kernel_param_shape,
                                                 jnp.float32,
                                                 axes=self.kernel_axes)

        kernel = jnp.asarray(kernel, self.dtype)
        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            # Use the canonicalized feature tuple: `(self.features,)` produced a
            # nested, invalid shape whenever `features` was a tuple or list.
            # For an int `features` the shape is unchanged: (features,).
            bias = nn_partitioning.param_with_axes('bias',
                                                   self.bias_init,
                                                   features,
                                                   self.dtype,
                                                   axes=self.bias_axes)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

        y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))

        if bias is not None:
            # The trailing len(features) axes of `y` are the feature axes;
            # broadcast the bias over all leading axes.
            y += jnp.reshape(bias, (1,) * (y.ndim - len(features)) + features)
        return y


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Attributes:
        transpose_batch_sequence: if True, activations are laid out as
            [length, batch, ...] instead of [batch, length, ...]; this selects
            the dropout broadcast axis and the sharding constraint.
        intermediate_dim: Shared dimension of hidden layers.
        activations: Type of activations for each layer.  Each element is either
            'linear', a string function name in flax.linen, or a function.
        kernel_init: Kernel function, passed to the dense layers; when None, a
            fan-in truncated-normal variance-scaling initializer is used.
        intermediate_dropout_rate: Dropout rate used after the intermediate layers.
        dtype: Type for the dense layer.
        fuse_wi: if True, all input projections are computed by one 'wi'
            DenseGeneral whose output is split evenly per activation.
    """
    transpose_batch_sequence: bool
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ('relu',)
    kernel_init: Initializer = None
    intermediate_dropout_rate: float = 0.1
    dtype: Any = jnp.float32
    fuse_wi: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, deterministic: bool = False):
        """Applies Transformer MlpBlock module.

        Args:
            inputs: input activations; the size of their last axis is also the
                output width of the final 'wo' projection.
            deterministic: Whether the dropout layer should be deterministic.

        Returns:
            Array whose last axis has the same size as that of `inputs`.
        """
        # Iterate over specified MLP input activation functions.
        # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.

        activations = []
        if self.fuse_wi:
            dense_name = 'wi'
            num_activations = len(self.activations)
            x = DenseGeneral(self.intermediate_dim * num_activations,
                             dtype=self.dtype,
                             kernel_init=self.kernel_init,
                             kernel_axes=('embed', 'mlp'),
                             name=dense_name)(inputs)
            x = jnp.split(x, num_activations, axis=-1)
            for idx, act_fn in enumerate(self.activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                activations.append(x_i)
        else:
            for idx, act_fn in enumerate(self.activations):
                dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
                x = DenseGeneral(self.intermediate_dim,
                                 dtype=self.dtype,
                                 kernel_init=self.kernel_init,
                                 kernel_axes=('embed', 'mlp'),
                                 name=dense_name)(inputs)
                x = _convert_to_activation_function(act_fn)(x)
                activations.append(x)

        # Take elementwise product of above intermediate activations.
        x = functools.reduce(operator.mul, activations)
        # Broadcast the dropout mask along the length axis.
        dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,)
        # Apply dropout and final dense output projection.
        x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)(
            x, deterministic=deterministic)
        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp'))
        output = DenseGeneral(inputs.shape[-1],
                              dtype=self.dtype,
                              kernel_init=self.kernel_init,
                              kernel_axes=('mlp', 'embed'),
                              name='wo')(x)
        return output


# Batched `lax.dynamic_slice_in_dim`: the operand, slice size and axis are
# shared (in_axes=None) while the start index is mapped over its leading axis,
# so one call takes one slice per entry of an index vector and stacks them.
dynamic_vector_slice_in_dim = jax.vmap(
    lax.dynamic_slice_in_dim,
    in_axes=(None, 0, None, None),
)


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    Attributes:
      num_heads: number of attention heads.
      head_dim: dimension of each head; the q/k/v projections produce
        num_heads * head_dim features.
      transpose_batch_sequence: if True, inputs are laid out as
        [length, batch, features] instead of [batch, length, features]; this
        selects the sharding constraints and the einsum layout passed to
        `dot_product_attention`.
      dtype: the dtype of the computation.
      dropout_rate: dropout rate applied to the attention weights.
      kernel_init: initializer for the kernel of the Dense layers; when None, a
        fan-in normal variance-scaling initializer is used.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
      scale_attn_logits: if True, the query is scaled by 1/sqrt(head_dim).
      scaled_query_init: if True, the query-kernel initializer output is
        divided by sqrt(head_dim).
      fuse_qkv: if True, use one fused 'qkv' projection when `inputs_q is
        inputs_kv`, otherwise a 'query' projection plus a fused 'kv' projection.
    """

    num_heads: int
    head_dim: int
    transpose_batch_sequence: bool
    dtype: DType = jnp.float32
    dropout_rate: float = 0.
    kernel_init: Initializer = None
    float32_logits: bool = False    # computes logits in float32 for stability.
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    fuse_qkv: bool = True

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
        super().__post_init__()

    @nn.compact
    def __call__(self,
                 inputs_q: Array,
                 inputs_kv: Array,
                 mask: Optional[Array] = None,
                 bias: Optional[Array] = None,
                 *,
                 decode: bool = False,
                 deterministic: bool = False) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and projects the results to an output
        vector.

        There are two modes: decoding and non-decoding (e.g., training). The
        mode is determined by the `decode` argument. For decoding, this method
        is called twice: first to initialize the cache and then for an actual
        decoding step. The two calls are differentiated by the presence of
        'cached_key' in the variable dict. In the cache initialization stage,
        the cache variables are initialized as zeros and will be filled in the
        subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape
        [batch, 1, qkv_features] corresponding to a single step.

        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.

        Returns:
          output of shape `[batch, length, q_features]` (length-major when
          `transpose_batch_sequence` is True).
        """
        projection = functools.partial(DenseGeneral,
                                       axis=-1,
                                       features=self.num_heads * self.head_dim,
                                       kernel_axes=('embed', 'joined_kv'),
                                       dtype=self.dtype)

        # NOTE: T5 does not explicitly rescale the attention logits by
        #       1/sqrt(depth_kq)!  This is folded into the initializers of the
        #       linear transformations, which is equivalent under Adafactor
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
        query_init = lambda *args: self.kernel_init(*args) / (    # pylint: disable=unnecessary-lambda-assignment
            depth_scaling if self.scaled_query_init else 1.0)

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]

        def qkv_init(key, shape, dtype):
            # Initializer for the fused 2-D (in_features, 3 * num_heads * head_dim)
            # kernel: [query | key | value] blocks concatenated on the last axis.
            assert shape[-1] % 3 == 0

            q_shape = (shape[0], shape[1] // 3)
            k_shape = (shape[0], shape[1] // 3)
            v_shape = (shape[0], shape[1] // 3)

            # Use an independent sub-key per block. Reusing `key` for all three
            # made the K and V kernels identical and Q a rescaled copy of them.
            q_key, k_key, v_key = jax_random.split(key, 3)
            q_kernel = query_init(q_key, q_shape, dtype)
            k_kernel = self.kernel_init(k_key, k_shape, dtype)    # pylint: disable=too-many-function-args
            v_kernel = self.kernel_init(v_key, v_shape, dtype)    # pylint: disable=too-many-function-args

            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

        if self.fuse_qkv:
            if inputs_q is inputs_kv:
                qkv_proj = DenseGeneral(axis=-1,
                                        features=self.num_heads * self.head_dim * 3,
                                        kernel_axes=('embed', 'joined_kv'),
                                        kernel_init=qkv_init,
                                        name='qkv',
                                        dtype=self.dtype)(inputs_kv)
                query, key, value = jnp.split(
                    qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1)
                if self.scale_attn_logits:
                    query = query / depth_scaling
            else:
                query = projection(kernel_init=query_init, name='query')( \
                        (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
                kv_proj = DenseGeneral(axis=-1,
                                       features=self.num_heads * self.head_dim * 2,
                                       kernel_axes=('embed', 'joined_kv'),
                                       kernel_init=self.kernel_init,
                                       name='kv',
                                       dtype=self.dtype)(inputs_kv)
                key, value = jnp.split(kv_proj, [self.num_heads * self.head_dim], axis=-1)
        else:
            query = projection(kernel_init=query_init, name='query')( \
                    (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
            key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
            value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv)

        query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
        key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
        value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))

        if self.transpose_batch_sequence:
            query = nn_partitioning.with_sharding_constraint(query,
                                                             ('length', 'batch', 'heads', 'kv'))
            key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv'))
            value = nn_partitioning.with_sharding_constraint(value,
                                                             ('length', 'batch', 'heads', 'kv'))
        else:
            query = nn_partitioning.with_sharding_constraint(query,
                                                             ('batch', 'length', 'heads', 'kv'))
            key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
            value = nn_partitioning.with_sharding_constraint(value,
                                                             ('batch', 'length', 'heads', 'kv'))

        if decode:
            # NOTE(review): the cache layout below assumes [batch, length, ...]
            # inputs; confirm behavior with transpose_batch_sequence=True.
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of a
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])    # pylint: disable=unnecessary-lambda-assignment
            cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
                                       key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
                                         value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
                # and cache the keys and values step by step.
                # Sanity shape check of cached key against input query.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        f"expected query shape {expected_shape} instead got {query.shape}.")

                # Create a OHE of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                # In order to update the key, value caches with the current key and
                # value, we move the length axis to the back, similar to what we did for
                # the cached ones above.
                # Note these are currently the key and value of a single position, since
                # we feed one position at a time.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)
                # Update key, value caches with our new 1d spatial slices.
                # We implement an efficient scatter into the cache via one-hot
                # broadcast and addition.
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Move the keys and values back to their original shapes.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have already
                # been generated and cached, not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    # (batch, 1, 1, length): the trailing (1, 1, length) stand for
                    # (head dim, query length, key length); query length is 1
                    # because decoding handles one index at a time. The same mask
                    # is applied to all batch elements and heads.
                    jnp.broadcast_to(jnp.arange(length) <= cur_index, (batch, 1, 1, length)))

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
                if bias is not None:
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
                    bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
                                                       jnp.reshape(cur_index, (-1)), 1, -2)

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(mask > 0,
                                        jnp.full(mask.shape, 0.).astype(self.dtype),
                                        jnp.full(mask.shape, -1e10).astype(self.dtype))
        else:
            attention_bias = None

        # Add provided bias term (e.g. relative position embedding).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # Apply attention.
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  transpose_batch_sequence=self.transpose_batch_sequence,
                                  bias=attention_bias,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=self.dropout_rate,
                                  deterministic=deterministic,
                                  dtype=self.dtype,
                                  float32_logits=self.float32_logits)

        # Merge the heads back into one feature axis.
        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv'))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv'))

        # Back to the original inputs dimensions.
        out = DenseGeneral(
            features=inputs_q.shape[-1],    # output dim is set to the input dim.
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=('joined_kv', 'embed'),
            dtype=self.dtype,
            name='out')(x)
        return out


class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data.

    Attributes:
        epsilon: constant added to the variance (or mean square) before the
            reciprocal square root.
        dtype: dtype of the output; the float32 parameters are cast to it.
        layernorm_type: 'layernorm' (mean-centred, with an 'ln_bias') or
            'rmsnorm' (no centring, no bias).
        zero_centered_gamma: if True, the output uses `(scale + 1)` and the
            default scale initializer becomes zeros. Only supported with
            'layernorm'.
        scale_init: initializer for 'scale'; when None, ones (or zeros when
            `zero_centered_gamma` is set).
        bias_init: initializer for 'ln_bias' (default: zeros).
    """
    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    bias_init: Initializer = nn.initializers.zeros

    def __post_init__(self):
        if self.scale_init is None:
            if not self.zero_centered_gamma:
                self.scale_init = nn.initializers.ones
            else:
                self.scale_init = nn.initializers.zeros
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies layer normalization on the input.

        Statistics are computed in float32 over the last axis; the normalized
        values are cast to `self.dtype` before the scale (and bias) are applied.
        """

        x = jnp.asarray(x, jnp.float32)
        features = x.shape[-1]

        scale = nn_partitioning.param_with_axes('scale',
                                                self.scale_init, (features,),
                                                jnp.float32,
                                                axes=('embed',))
        scale = jnp.asarray(scale, self.dtype)

        if self.layernorm_type == 'layernorm':
            mean = jnp.mean(x, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
            y = (x - mean) * lax.rsqrt(var + self.epsilon)

            bias = nn_partitioning.param_with_axes('ln_bias',
                                                   self.bias_init, (features,),
                                                   jnp.float32,
                                                   axes=('embed',))
            bias = jnp.asarray(bias, self.dtype)

            y = jnp.asarray(y, self.dtype)
            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
                # Zero-centred gamma: the stored scale is an offset from 1.
                z = y * (scale + 1) + bias
        else:
            assert self.layernorm_type == 'rmsnorm'
            assert not self.zero_centered_gamma
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
            z = y * scale

        return z


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Attributes:
        num_buckets: Number of buckets to bucket distances between key and query
            positions into.
        max_distance: Maximum distance before everything is lumped into the last
            distance bucket.
        num_heads: Number of heads in the attention layer. Each head will get a
            different relative position weighting.
        dtype: Type of arrays through this module.
        embedding_init: initializer for relative embedding table.
    """
    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, positive relative positions are treated
        as invalid and are all clamped into bucket 0.
        Smaller buckets are used for small absolute relative_position and larger
        (log-spaced) buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative
        positions <= -max_distance map to the same bucket. This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 (numpy) array
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            an array with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split the buckets in half: the upper half is used for n < 0, i.e.
            # memory_position > query_position.
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        # Log-spaced buckets for larger distances; eps keeps log() finite when n == 0
        # (those entries are discarded by the np.where below anyway).
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
            np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
            qlen: attention query length.
            klen: attention key length.
            bidirectional: whether to allow positive memory-query relative position
                embeddings.

        Returns:
            output: `(1, num_heads, qlen, klen)` attention bias
        """
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position    # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(relative_position,
                                                   bidirectional=bidirectional,
                                                   num_buckets=self.num_buckets,
                                                   max_distance=self.max_distance)
        relative_attention_bias = nn_partitioning.param_with_axes(
            'rel_embedding',
            self.embedding_init, (self.num_heads, self.num_buckets),
            jnp.float32,
            axes=('heads', 'relpos_buckets'))

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
        # array from rp_bucket and use it to perform the gather-equivalent via a
        # contraction, i.e.:
        # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
        # This is equivalent to relative_attention_bias[:, rp_bucket]
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (num_heads, qlen, klen)
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
            (
                ((1,), (0,)),    # lhs, rhs contracting dims
                ((), ())))    # no batched dims
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]


class EncoderLayer(nn.Module):
    """Transformer encoder layer.

    With ``output_layernorm=False`` (the default) the data flow is::

        x = inputs + drop(attention(LayerNorm(inputs)))
        y = x + drop(mlp(LayerNorm(x)))

    With ``output_layernorm=True`` the pre-attention LayerNorm is skipped and a final
    ``output_layer_norm`` is applied to ``y``. With
    ``apply_residual_connection_post_layernorm=True`` a residual is taken from the
    LayerNorm output instead of that LayerNorm's input. ``drop`` is dropout broadcast
    along the sequence dimension, followed by an extra drop-path dropout when
    ``drop_path > 0``.
    """
    # Module called as rel_emb(q_len, k_len, bidirectional) to build the attention bias.
    # If None, a RelativePositionBiases (32 buckets, max distance 128) named
    # 'relpos_bias' is created inside __call__.
    relative_embedding: nn.Module = None
    num_heads: int = 8
    head_dim: int = 64
    # Used for the attention dropout, the post-block dropouts and the MLP
    # intermediate dropout.
    dropout_rate: float = 0.1
    # True: sequence dimension is 0 and batch dimension is 1; False: the reverse.
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ('relu',)
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    output_layernorm: bool = False
    # Rate of an additional dropout whose broadcast dims come from
    # _generate_drop_path_shape; presumably per-sample stochastic depth — confirm.
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = False

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        """Apply the encoder layer.

        Args:
            inputs: input activations; the sequence dimension is 0 when
                ``transpose_batch_sequence`` is True, else 1.
            encoder_mask: attention mask forwarded to the self-attention module.
            deterministic: forwarded to every dropout (True disables dropout).

        Returns:
            The layer output; the residual adds imply it has the shape of ``inputs``.
        """
        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.relative_embedding is None:
            rel_emb = RelativePositionBiases(num_buckets=32,
                                             max_distance=128,
                                             num_heads=self.num_heads,
                                             dtype=self.dtype,
                                             embedding_init=nn.initializers.variance_scaling(
                                                 1.0, 'fan_avg', 'uniform'),
                                             name='relpos_bias')
        else:
            rel_emb = self.relative_embedding
        # Self-attention over the same sequence, so q_len == k_len; bidirectional.
        encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)

        # Attention block.
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          dtype=self.dtype,
                          name="pre_attention_layer_norm")(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = MultiHeadAttention(num_heads=self.num_heads,
                               dtype=self.dtype,
                               head_dim=self.head_dim,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               dropout_rate=self.dropout_rate,
                               float32_logits=self.float32_attention_logits,
                               scale_attn_logits=self.scale_attn_logits,
                               scaled_query_init=self.scaled_query_init,
                               fuse_qkv=self.fuse_qkv_params,
                               name='attention')(x,
                                                 x,
                                                 encoder_mask,
                                                 encoder_bias,
                                                 deterministic=deterministic)
        x = nn.Dropout(rate=self.dropout_rate,
                       broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
        x = x + residual

        # MLP block.
        residual = x
        y = LayerNorm(layernorm_type=self.layernorm_type,
                      zero_centered_gamma=self.zero_centered_gamma,
                      dtype=self.dtype,
                      name='pre_mlp_layer_norm')(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name='mlp',
        )(y, deterministic=deterministic)
        y = nn.Dropout(rate=self.dropout_rate,
                       broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
            y = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(y, deterministic=deterministic)
        y = y + residual

        if self.output_layernorm:
            y = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          dtype=self.dtype,
                          name="output_layer_norm")(y)
        return y


class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder.

    With ``output_layernorm=False`` (the default) the data flow is::

        x = inputs + drop(self_attention(LayerNorm(inputs)))
        y = x + dropout(cross_attention(LayerNorm(x), encoded))
        z = y + drop(mlp(LayerNorm(y)))

    With ``output_layernorm=True`` the pre-self-attention LayerNorm is skipped and a
    final ``output_layer_norm`` is applied to ``z``. With
    ``apply_residual_connection_post_layernorm=True`` a residual is taken from the
    LayerNorm output instead of that LayerNorm's input. ``drop`` is dropout broadcast
    along the sequence dimension, followed by an extra drop-path dropout when
    ``drop_path > 0``; the cross-attention block uses plain dropout only.
    """
    # Module called as rel_emb(q_len, k_len, bidirectional) to build the self-attention
    # bias. If None, a RelativePositionBiases (32 buckets, max distance 128) named
    # 'relpos_bias' is created inside __call__.
    relative_embedding: nn.Module = None
    num_heads: int = 8
    head_dim: int = 64
    # Used for the attention dropouts, the post-block dropouts and the MLP
    # intermediate dropout.
    dropout_rate: float = 0.1
    # True: sequence dimension is 0 and batch dimension is 1; False: the reverse.
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ('relu',)
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    layernorm_type: str = 'layernorm'
    zero_centered_gamma: bool = False
    # Rate of an additional dropout whose broadcast dims come from
    # _generate_drop_path_shape; presumably per-sample stochastic depth — confirm.
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = False

    @nn.compact
    def __call__(self,
                 inputs,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 deterministic=False,
                 decode=False,
                 max_decode_length=None):
        """Apply the decoder layer.

        Args:
            inputs: decoder input activations; the sequence dimension is 0 when
                ``transpose_batch_sequence`` is True, else 1.
            encoded: encoder output, used as keys/values of the cross-attention.
            decoder_mask: mask forwarded to the self-attention module.
            encoder_decoder_mask: mask forwarded to the cross-attention module.
            deterministic: forwarded to every dropout (True disables dropout).
            decode: forwarded to the self-attention module; together with
                ``max_decode_length`` it also sets the relative-bias length.
            max_decode_length: when ``decode`` is set and this is truthy, the
                relative-position bias is built for this length instead of the
                input sequence length.

        Returns:
            The layer output; the residual adds imply it has the shape of ``inputs``.
        """
        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim
        l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
        if self.relative_embedding is None:
            rel_emb = RelativePositionBiases(num_buckets=32,
                                             max_distance=128,
                                             num_heads=self.num_heads,
                                             dtype=self.dtype,
                                             embedding_init=nn.initializers.variance_scaling(
                                                 1.0, 'fan_avg', 'uniform'),
                                             name='relpos_bias')
        else:
            rel_emb = self.relative_embedding
        # Unidirectional bias for decoder self-attention.
        decoder_bias = rel_emb(l, l, False)

        # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          dtype=self.dtype,
                          name="pre_self_attention_layer_norm")(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # Self-attention block
        x = MultiHeadAttention(num_heads=self.num_heads,
                               dtype=self.dtype,
                               head_dim=self.head_dim,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               dropout_rate=self.dropout_rate,
                               float32_logits=self.float32_attention_logits,
                               scale_attn_logits=self.scale_attn_logits,
                               scaled_query_init=self.scaled_query_init,
                               fuse_qkv=self.fuse_qkv_params,
                               name='self_attention')(x,
                                                      x,
                                                      decoder_mask,
                                                      decoder_bias,
                                                      deterministic=deterministic,
                                                      decode=decode)
        x = nn.Dropout(rate=self.dropout_rate,
                       broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
        x = x + residual

        # Encoder-Decoder block.
        residual = x
        y = LayerNorm(layernorm_type=self.layernorm_type,
                      zero_centered_gamma=self.zero_centered_gamma,
                      dtype=self.dtype,
                      name='pre_cross_attention_layer_norm')(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y
        y = MultiHeadAttention(num_heads=self.num_heads,
                               dtype=self.dtype,
                               head_dim=self.head_dim,
                               transpose_batch_sequence=self.transpose_batch_sequence,
                               dropout_rate=self.dropout_rate,
                               float32_logits=self.float32_attention_logits,
                               scale_attn_logits=self.scale_attn_logits,
                               scaled_query_init=self.scaled_query_init,
                               fuse_qkv=self.fuse_qkv_params,
                               name='encoder_decoder_attention')(y,
                                                                 encoded,
                                                                 encoder_decoder_mask,
                                                                 deterministic=deterministic)
        y = nn.Dropout(rate=self.dropout_rate,
                       broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
        # NOTE(review): unlike the self-attention and MLP blocks, no drop_path is
        # applied here — confirm this matches the implementation under test.
        y = y + residual

        # MLP block.
        residual = y
        z = LayerNorm(layernorm_type=self.layernorm_type,
                      zero_centered_gamma=self.zero_centered_gamma,
                      dtype=self.dtype,
                      name='pre_mlp_layer_norm')(y)
        if self.apply_residual_connection_post_layernorm:
            residual = z
        z = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name='mlp',
        )(z, deterministic=deterministic)
        z = nn.Dropout(rate=self.dropout_rate,
                       broadcast_dims=(sequence_dim,))(z, deterministic=deterministic)
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path,
                           broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(layernorm_type=self.layernorm_type,
                          zero_centered_gamma=self.zero_centered_gamma,
                          dtype=self.dtype,
                          name="output_layer_norm")(z)

        return z


def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Assert that ``actual`` and ``desired`` are element-wise close, compared in float32.

    Array inputs are cast to float32 before the comparison. Presumably this lets
    low-precision dtypes (e.g. bfloat16) be compared with numpy — confirm. Python
    floats are passed through unchanged. Other array-likes without ``astype``
    (lists, tuples, ints) are converted with ``np.asarray(..., float32)``.

    Args:
        actual: computed values (array, Python float, or array-like).
        desired: expected values; same accepted types as ``actual``.
        rtol: relative tolerance, forwarded to ``np.testing.assert_allclose``.
        atol: absolute tolerance, forwarded to ``np.testing.assert_allclose``.
        equal_nan: whether NaNs compare equal, forwarded as-is.
        err_msg: message appended on failure, forwarded as-is.
        verbose: forwarded as-is.

    Raises:
        AssertionError: if the values are not close within the tolerances.
    """

    def _as_float32(value):
        # Python floats keep full precision, as before.
        if isinstance(value, float):
            return value
        # Arrays keep the original conversion path via their own astype.
        if hasattr(value, 'astype'):
            return value.astype(jnp.float32)
        # Plain array-likes (list, tuple, int, ...) previously raised AttributeError.
        return np.asarray(value, dtype=np.float32)

    actual = _as_float32(actual)
    desired = _as_float32(desired)
    np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)