# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""

import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random

from transformer_engine.jax.attention import (
    AttnMaskType,
    canonicalize_attn_mask_type,
    make_swa_mask,
)
from transformer_engine.jax.fp8 import DType as TEDType

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]

# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))


def is_devices_enough(required):
    """
    Check whether enough GPUs are available.
    """
    return len(jax.devices()) >= required


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def combine_biases(*masks: Optional[Array]):
    """Combine attention biases.

    Args:
      *masks: set of attention bias arguments to combine; some can be None.

    Returns:
      Combined bias, reduced by summation; returns None if no biases are given.
    """
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


class DotProductAttention(nn.Module):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.

    Args:
        dropout_rate: dropout rate
        dtype: the data type used to allocate the initial parameters (default: float32).
        float32_logits: bool, if True then compute logits in float32 to avoid
            numerical issues with bfloat16.
    """

    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = True
    dropout_rate: float = 0.0
    dtype: DType = jnp.float32
    float32_logits: bool = False

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        bias: Optional[Array] = None,
        deterministic: bool = False,
    ):
        """
        Args:
            query: queries for calculating attention with shape of `[batch, q_length,
            num_heads, qk_depth_per_head]`.
            key: keys for calculating attention with shape of `[batch, kv_length,
            num_gqa_groups, qk_depth_per_head]`.
            value: values to be used in attention with shape of `[batch, kv_length,
            num_gqa_groups, v_depth_per_head]`.
            bias: bias for the attention weights. This should be broadcastable to the
            shape `[batch, num_heads, q_length, kv_length]`. This can be used for
            incorporating causal masks, padding masks, proximity bias, etc. Dropout
            draws its randomness from the 'dropout' PRNG stream via `self.make_rng`.
            deterministic: bool, deterministic or not (to apply dropout)
        Returns:
            Output of shape `[batch, length, num_heads, v_depth_per_head]`.
        """
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_attn_logits:
            head_dim = query.shape[-1]
            depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
            query = query / depth_scaling

        # Casting logits and softmax computation for float32 for model stability.
        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        # `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
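        # For grouped-query attention, the h_q query heads are folded into h_kv
        # groups of size h_q // h_kv, so each KV head is shared by `group_size`
        # query heads (h_q == h_kv reduces this to standard multi-head attention).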
        h_q, h_kv = query.shape[-2], key.shape[-2]
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
        else:
            attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)

        # reshape back to normal DPA shape for bias/softmax/dropout
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # Apply attention bias: masking, dropout, proximity bias, etc.
        if bias is not None:
            attn_weights = attn_weights + bias.astype(attn_weights.dtype)

        # Normalize the attention weights across `kv_length` dimension.
        attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)

        # Apply attention dropout.
        if not deterministic and self.dropout_rate > 0.0:
            keep_prob = 1.0 - self.dropout_rate
            # T5 broadcasts along the "length" dim, but unclear which one that
            # corresponds to in positional dimensions here, assuming query dim.
            dropout_shape = list(attn_weights.shape)
            dropout_rng = self.make_rng("dropout")
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
            attn_weights = attn_weights * multiplier

        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
        attn_weights = attn_weights.astype(value.dtype)

        # Take the linear combination of `value`.
        if self.transpose_batch_sequence:
            return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes and FP8 support.

    Attributes:
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
    dtype: the data type used to allocate the initial parameters (default: float32).
    kernel_init: initializer function for the weight matrix.
    use_bias: whether to add a bias to the output (default: False).
    bias_init: initializer function for the bias vector.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
        inputs: The nd-array to be transformed.

        Returns:
        The transformed input.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
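        # The kernel is stored as a flattened 2-D parameter (prod of contracted
        # input dims, prod of output features) and reshaped to `kernel_shape`
        # before the contraction below.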
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
        )

        kernel = jnp.asarray(kernel, input_dtype)
        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

        y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
        y = y.astype(input_dtype)

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Attributes:
      intermediate_dim: Shared dimension of hidden layers.
      activations: Type of activations for each layer.  Each element is either
        'linear', a string function name in flax.linen, or a function.
      kernel_init: Kernel function, passed to the dense layers.
      deterministic: Whether the dropout layers should be deterministic.
      intermediate_dropout_rate: Dropout rate used after the intermediate layers.
      dtype: the data type used to allocate the initial parameters (default: float32).
    """

    transpose_batch_sequence: bool
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    kernel_init: Initializer = None
    intermediate_dropout_rate: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    use_bias: bool = False
    dtype: Any = jnp.float32
    fuse_wi: bool = True

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, deterministic: bool = False):
        """Applies Transformer MlpBlock module."""
        # Iterate over specified MLP input activation functions.
        # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
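        # With fuse_wi=True a single fused "wi" projection produces all
        # len(self.activations) chunks at once and is split afterwards;
        # otherwise each activation gets its own "wi"/"wi_{idx}" projection.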

        activations = []
        if self.fuse_wi:
            dense_name = "wi"
            num_activations = len(self.activations)
            x = DenseGeneral(
                self.intermediate_dim * num_activations,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                kernel_axes=("embed", "mlp"),
                use_bias=self.use_bias,
                bias_axes="mlp",
                name=dense_name,
            )(inputs)
            x = jnp.split(x, num_activations, axis=-1)
            for idx, act_fn in enumerate(self.activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                activations.append(x_i)
        else:
            for idx, act_fn in enumerate(self.activations):
                dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
                x = DenseGeneral(
                    self.intermediate_dim,
                    dtype=self.dtype,
                    kernel_init=self.kernel_init,
                    kernel_axes=("embed", "mlp"),
                    use_bias=self.use_bias,
                    bias_axes="mlp",
                    name=dense_name,
                )(inputs)
                x = _convert_to_activation_function(act_fn)(x)
                activations.append(x)

        # Take elementwise product of above intermediate activations.
        x = functools.reduce(operator.mul, activations)
        # Apply dropout and final dense output projection.
        x = nn.Dropout(
            rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
        )(
            x, deterministic=deterministic
        )  # Broadcast along length.
        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
        output = DenseGeneral(
            inputs.shape[-1],
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("mlp", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            name="wo",
        )(x)
        assert (
            output.dtype == inputs.dtype
        ), f"inputs.dtype={inputs.dtype}, output.dtype={output.dtype}"
        return output


def apply_rotary_pos_emb_alternate(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
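    """Apply rotary position embedding using the half-rotation ("alternate") layout.

    The feature dimension is split into two halves; element i of the first half
    is rotated together with element i + dim // 2 of the second half by a
    position-dependent angle.
    """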
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(inputs.dtype)
    second_part = second_part.astype(inputs.dtype)
    return jnp.concatenate([first_part, second_part], axis=-1)


def apply_rotary_pos_emb_consecutive(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
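    """Apply rotary position embedding using the consecutive-pair layout.

    Each adjacent feature pair (x[2i], x[2i + 1]) shares a timescale and is
    rotated by a position-dependent angle; the rotation is implemented with
    shifted copies of the input and an alternating sign mask instead of an
    explicit reshape into pairs.
    """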
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim

    inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    inputs_shifted = jax.lax.select(
        jnp.tile(
            jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
            inputs.shape[:-1] + (1,),
        ),
        inputs_shifted_right,
        inputs_shifted_left,
    )
    fraction = jnp.repeat(fraction, 2)
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction

    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))

    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
    outputs = inputs * cos + inputs_shifted * sin * sign

    return outputs


dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
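# (vmapped over the leading axis of `start_index`: slices `slice_size` elements
# along `axis` from a shared operand at a per-example start position.)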


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      num_gqa_groups: number of kv attention heads
      head_dim: dimension of each head.
      dtype: the data type used to allocate the initial parameters (default: float32).
      dropout_rate: dropout rate
      kernel_init: initializer for the kernel of the Dense layers.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    num_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    transpose_batch_sequence: bool = True
    dtype: DType = jnp.float32
    dropout_rate: float = 0.0
    kernel_init: Initializer = None
    float32_logits: bool = False  # computes logits in float32 for stability.
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv: bool = True
    use_bias: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention, and projects the results to an output vector.

        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by the `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for an actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.

        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.

        Returns:
          output of shape `[batch, length, q_features]`.
        """
        q_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_heads * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        kv_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_gqa_groups * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        # NOTE: T5 does not explicitly rescale the attention logits by
        #       1/sqrt(depth_kq)!  This is folded into the initializers of the
        #       linear transformations, which is equivalent under Adafactor
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
        query_init = lambda *args: self.kernel_init(*args) / (
            depth_scaling if self.scaled_query_init else 1.0
        )

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]

        def qkv_init(key, shape, dtype):
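            # Fused QKV kernel layout: [query | key | value] blocks concatenated
            # along the last (output) dim; each third is initialized separately so
            # the query block can use the depth-scaled initializer.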
            assert shape[-1] % 3 == 0

            q_shape = (shape[0], shape[1] // 3)
            k_shape = (shape[0], shape[1] // 3)
            v_shape = (shape[0], shape[1] // 3)

            q_kernel = query_init(key, q_shape, dtype)
            k_kernel = self.kernel_init(key, k_shape, dtype)
            v_kernel = self.kernel_init(key, v_shape, dtype)

            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa

        if self.fuse_qkv:
            if is_qkvpack:
                qkv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_heads * self.head_dim * 3,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_kv)
                query, key, value = jnp.split(
                    qkv_proj,
                    [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1,
                )
            else:
                query = q_projection(kernel_init=query_init, name="query")(inputs_q)

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_gqa_groups * self.head_dim * 2,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
        else:
            query = q_projection(kernel_init=query_init, name="query")(inputs_q)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)

        if self.enable_rotary_pos_emb:
            batch_dim = 1 if self.transpose_batch_sequence else 0
            seq_dim = 1 - batch_dim

            q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(key.shape[seq_dim]), axis=batch_dim)

            if self.rotary_pos_emb_group_method == "alternate":
                apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
            else:
                apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive

            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            query = apply_rotary_pos_emb(query, q_position)
            key = apply_rotary_pos_emb(key, k_position)

        query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
        key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if self.transpose_batch_sequence:
            query = nn_partitioning.with_sharding_constraint(
                query, ("length", "batch", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("length", "batch", "heads", "kv")
            )
        else:
            query = nn_partitioning.with_sharding_constraint(
                query, ("batch", "length", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("batch", "length", "heads", "kv")
            )

        if decode:
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of a
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
                # and cache the keys and values step by step.
                # Sanity shape check of cached key against input query.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                # Create a one-hot encoding of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                # In order to update the key, value caches with the current key and
                # value, we move the length axis to the back, similar to what we did for
                # the cached ones above.
                # Note these are currently the key and value of a single position, since
                # we feed one position at a time.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)
                # Update key, value caches with our new 1d spatial slices.
                # We implement an efficient scatter into the cache via one-hot
                # broadcast and addition.
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Move the keys and values back to their original shapes.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have already
                # been generated and cached, not the remaining zero elements.
                mask = combine_masks(
                    jnp.logical_not(mask),
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
                        # (1, 1, length) represent (head dim, query length, key length)
                        # query length is 1 because during decoding we deal with one
                        # index.
                        # The same mask is applied to all batch elements and heads.
                        (batch, 1, 1, length),
                    ),
                )

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
                if bias is not None:
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.0).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        # Add provided bias term (e.g. relative position embedding).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        # Apply attention.
        x = DotProductAttention(
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )(query, key, value, bias=attention_bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))

        # Back to the original inputs dimensions.
        out = DenseGeneral(
            features=inputs_q.shape[-1],  # output dim is set to the input dim.
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            dtype=self.dtype,
            name="out",
        )(x)
        assert (
            inputs_q.dtype == inputs_kv.dtype == out.dtype
        ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
        return out


class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data."""

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
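    # With zero_centered_gamma=True the learnable scale is stored as (gamma - 1),
    # initialized to zeros, and applied as (scale + 1.0) in __call__.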
    scale_init: Initializer = None
    bias_init: Initializer = nn.initializers.zeros

    def __post_init__(self):
        if self.scale_init is None:
            if not self.zero_centered_gamma:
                self.scale_init = nn.initializers.ones
            else:
                self.scale_init = nn.initializers.zeros
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies layer normalization on the input."""

        input_dtype = x.dtype
        features = x.shape[-1]

        scale = nn_partitioning.param_with_axes(
            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
        )
        scale = jnp.asarray(scale, input_dtype)

        if self.layernorm_type == "layernorm":
            mean = jnp.mean(x, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
            y = (x - mean) * lax.rsqrt(var + self.epsilon)

            bias = nn_partitioning.param_with_axes(
                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
            )
            bias = jnp.asarray(bias, input_dtype)

            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
                z = y * (scale + 1.0) + bias
        else:
            assert self.layernorm_type == "rmsnorm"
            assert not self.zero_centered_gamma
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = x * lax.rsqrt(mean2 + self.epsilon)
            z = y * scale

        assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
        return z


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Attributes:
      num_buckets: Number of buckets to bucket distances between key and query
        positions into.
      max_distance: Maximum distance before everything is lumped into the last
        distance bucket.
      num_heads: Number of heads in the attention layer. Each head will get a
        different relative position weighting.
      dtype: the data type used to allocate the initial parameters (default: float32).
      embedding_init: initializer for relative embedding table.
    """

    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions.  All relative
        positions >=max_distance  map to the same bucket.  All relative
        positions <=-max_distance map to the same bucket.  This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
          relative_position: an int32 array
          bidirectional: a boolean - whether the attention is bidirectional
          num_buckets: an integer
          max_distance: an integer

        Returns:
          a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
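
        Illustrative layout with the defaults (bidirectional=True, num_buckets=32,
        max_distance=128): past offsets map to buckets 0..15 (0..7 exactly, larger
        offsets log-spaced and saturating at 15), while future offsets map to the
        mirrored range 16..31.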
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
          qlen: attention query length.
          klen: attention key length.
          bidirectional: whether to allow positive memory-query relative position
            embeddings.

        Returns:
          output: `(1, len, q_len, k_len)` attention bias
        """
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        relative_attention_bias = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_heads, self.num_buckets),
            jnp.float32,
            axes=("heads", "relpos_buckets"),
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
        # array from rp_bucket and use it to perform the gather-equivalent via a
        # contraction, i.e.:
        # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
        # This is equivalent to relative_attention_bias[:, rp_bucket]
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (num_heads, qlen, klen)
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
            (((1,), (0,)), ((), ())),  # lhs/rhs contracting dims; no batched dims
        )
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]


def apply_swa_mask(
    attn_mask_type: str,
    original_mask: Array,
    window_size: Tuple[int, int] = (-1, -1),
) -> Array:
    """Apply the sliding window mask to a given mask"""
    _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
    assert _attn_mask_type is not None
    batch = original_mask.shape[0]
    max_seqlen_q = original_mask.shape[-2]
    max_seqlen_kv = original_mask.shape[-1]
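    # window_size follows the (left, right) convention of make_swa_mask from
    # transformer_engine.jax.attention: positions outside the window are masked
    # out, and -1 on a side means no bound on that side.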
    pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
    pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
    swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
    # In both swa_mask and original_mask, a value of 0 means masked out.
    new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
    return new_mask


class EncoderLayer(nn.Module):
    """Transformer encoder layer."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    output_layernorm: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        # The cuDNN backend currently supports SWA only for causal/padding_causal
        # mask types, so follow the same convention here.
        encoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            encoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
        else:
            encoder_bias = None

        # Attention block.
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            fuse_qkv=self.fuse_qkv_params,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            use_bias=self.use_bias,
            name="attention",
        )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # MLP block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(y, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
            y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                y, deterministic=deterministic
            )
        y = y + residual

        if self.output_layernorm:
            y = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(y)
        assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
        return y


class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs,
        encoded,
        decoder_mask=None,
        encoder_decoder_mask=None,
        deterministic=False,
        decode=False,
        max_decode_length=None,
    ):
        decoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            decoder_mask,
            self.window_size,
        )

        encoder_decoder_mask = apply_swa_mask(
            "padding",
            encoder_decoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            decoder_bias = rel_emb(l, l, False)
        else:
            decoder_bias = None

        # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_self_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # Self-attention block
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="self_attention",
        )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # Encoder-Decoder block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_cross_attention_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y
        y = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="encoder_decoder_attention",
        )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        y = y + residual

        # MLP block.
        residual = y
        z = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(y)
        if self.apply_residual_connection_post_layernorm:
            residual = z
        z = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(z, deterministic=deterministic)
        z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            z, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
        return z


def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate a causal attention mask where an entry of 1 marks a masked (future) position.
    """
    shape = (batch, seqlen)
    idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)

    mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)

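
# Hedged usage sketch (not part of the original helpers): in this file's mask
# convention an entry of 1 marks a position that must NOT be attended, so the
# causal mask is 0 on/below the diagonal and 1 strictly above it. The function
# name `_example_make_causal_mask` is introduced here purely for illustration.
def _example_make_causal_mask():
    causal = make_causal_mask(2, 4)
    assert causal.shape == (2, 1, 4, 4)  # [batch, 1, seqlen, seqlen]
    assert causal[0, 0, 0, 1] == 1  # future position -> masked out
    assert causal[0, 0, 3, 0] == 0  # past position -> attended
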

def make_self_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate a full self-attention mask with no masked positions (every entry is 0).
    """
    shape = (batch, seqlen)
    mask = jnp.ones((*shape, shape[-1]))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)

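
# Hedged usage sketch: the self mask above contains no masked positions (every
# entry is 0), i.e. full bidirectional visibility. `_example_make_self_mask` is
# an illustrative name, not part of the original helpers.
def _example_make_self_mask():
    mask = make_self_mask(2, 4)
    assert mask.shape == (2, 1, 4, 4)
    assert int(mask.sum()) == 0  # zero everywhere -> nothing is masked
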

def assert_allclose(
    actual: Array,
    desired: Array,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
    **kwargs,
) -> None:
    """Check if two tensors are close.

    Args:
      actual: test tensor.
      desired: reference tensor.
      rtol: relative tolerance (default: based on `dtype`).
      atol: absolute tolerance (default: based on `dtype`).
      dtype: data type or data type name (default: inferred from
        `actual`).
      **kwargs: keyword arguments to pass to np.testing.assert_allclose.
    """

    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)

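
# Hedged usage sketch for the dtype-aware comparison above: when rtol/atol are not
# given they default to tolerances derived from the dtype of `actual`, so a bfloat16
# comparison tolerates far more error than a float32 one. `_example_assert_allclose`
# is illustrative only.
def _example_assert_allclose():
    x = jnp.ones((4, 4), dtype=jnp.bfloat16)
    y = x + 1e-3  # well within the bfloat16-derived tolerances
    assert_allclose(x, y)
    assert_allclose(x.astype(jnp.float32), x.astype(jnp.float32), rtol=0.0, atol=0.0)
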

def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
    flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
    flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)

    for (expected_path, expected_value), (actual_path, actual_value) in zip(
        flatten_expected, flatten_actual
    ):
        assert expected_path == actual_path
        key_str = jax.tree_util.keystr(expected_path)
        assert_allclose(
            expected_value,
            actual_value,
            rtol=rtol,
            atol=atol,
            err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
        )

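
# Hedged usage sketch: compares two pytrees leaf by leaf, requiring matching key
# paths and close values. `_example_assert_tree_like_allclose` is illustrative only.
def _example_assert_tree_like_allclose():
    expected = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
    actual = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,)) + 1e-9}}
    assert_tree_like_allclose(expected, actual)  # passes with the default tolerances
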

def dtype_tols(
    dtype: Union[DType, TEDType, np.dtype],
    reference_value: float = 1.0,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> Dict[str, float]:
    """Expected numerical tolerance for a data type.

    Args:
      dtype: data type.
      reference_value: reference value (default: 1).
      rtol: override for relative tolerance estimate
      atol: override for absolute tolerance estimate

    Returns:
      Dictionary with "rtol" and "atol" as keys

    """

    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Convert to JAX dtype if needed
    if isinstance(dtype, TEDType):
        dtype = {
            TEDType.kByte: jnp.uint8,
            TEDType.kInt32: jnp.int32,
            TEDType.kInt64: jnp.int64,
            TEDType.kFloat32: jnp.float32,
            TEDType.kFloat16: jnp.float16,
            TEDType.kBFloat16: jnp.bfloat16,
            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = jnp.dtype(dtype)

    # Expect bit-wise accuracy for integer dtypes
    if not jnp.issubdtype(dtype, jnp.floating):
        if rtol is None:
            rtol = 0.0
        if atol is None:
            atol = 0.0
        return {"rtol": rtol, "atol": atol}

    # Estimate floating-point error
    finfo = jnp.finfo(dtype)
    eps_relaxed = math.pow(finfo.eps, 2 / 3)
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            reference_value = jnp.array(reference_value, dtype=dtype)
        else:
            reference_value = reference_value.astype(dtype)
        spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
        spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
        ulp = max(spacing_high.item(), spacing_low.item())
    if rtol is None:
        rtol = eps_relaxed
    if atol is None:
        atol = max(ulp, eps_relaxed)
    return {"rtol": rtol, "atol": atol}

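
# Hedged usage sketch: lower-precision dtypes get looser default tolerances. The
# exact numbers depend on the dtype's machine epsilon and the ULP spacing around
# `reference_value`, so only the relative ordering is checked here.
def _example_dtype_tols():
    fp32_tols = dtype_tols(jnp.float32)
    bf16_tols = dtype_tols(jnp.bfloat16)
    assert set(fp32_tols) == set(bf16_tols) == {"rtol", "atol"}
    assert bf16_tols["rtol"] > fp32_tols["rtol"]  # bfloat16 is far less precise
    assert dtype_tols(jnp.int32) == {"rtol": 0.0, "atol": 0.0}  # integers: exact match
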

def sync_params_values(dst, src, transformations, sep="/"):
    """
    This function will reconstuct a tree with dst's tree_def/shape and src's value.
    transformations is a map that records the key mappings between dst and src.
    If no dst key found in the transformerations, it will fall back to src key = dst key.
    transformations = {
        dst key map 0: src key map 0,
        dst key map 1: src key map 1,
        ...
        # if dst key = src key, we don't need to add it
    }
    """
    src_values = {}
    for key, value in jax.tree_util.tree_leaves_with_path(src):
        normalized_key = sep.join(x.key for x in key)
        src_values[normalized_key] = value

    flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
    synced_dst_values = []

    for key, value in flatten_dst:
        normalized_key = sep.join(x.key for x in key)
        if normalized_key in transformations:
            corresponding_src_key = transformations[normalized_key]
        else:
            corresponding_src_key = normalized_key
        synced_dst_values.append(src_values[corresponding_src_key])

    synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)

    return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)

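
# Hedged usage sketch with hypothetical parameter names ("encoder/attention" vs.
# "encoder/self_attention"): the values come from `src`, while the tree structure
# and leaf shapes come from `dst`.
def _example_sync_params_values():
    dst = {"encoder": {"self_attention": {"kernel": jnp.zeros((4, 4))}}}
    src = {"encoder": {"attention": {"kernel": jnp.ones((16,))}}}
    synced = sync_params_values(
        dst, src, {"encoder/self_attention/kernel": "encoder/attention/kernel"}
    )
    kernel = synced["encoder"]["self_attention"]["kernel"]
    assert kernel.shape == (4, 4) and float(kernel.sum()) == 16.0
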

@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [h[0], h[1]]
            fmt = fmt + "\n  {}\n  {}"

        jax.debug.print(fmt, *args)
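

# Hedged usage sketch: statistics are printed only when NVTE_DEBUG_NUMERICS=1 is set
# in the environment before this module is imported; otherwise the call traces to an
# effective no-op. `prefix` and `hist` are static arguments of the jitted function.
def _example_print_debug_tensor_stats():
    x = jax_random.normal(jax_random.PRNGKey(0), (8, 16))
    print_debug_tensor_stats("example/x", x)  # mean/min/max/numel/nzcnt
    print_debug_tensor_stats("example/x_hist", x, True)  # adds a 10-bin histogram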