# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""

import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os

import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks

from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random

import pytest

from transformer_engine.jax.attention import (
    AttnMaskType,
    canonicalize_attn_mask_type,
    make_swa_mask,
)
from transformer_engine.jax.quantize.helper import DType as TEDType

PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]

Initializer = Callable[[PRNGKey, Shape, DType], Array]

# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))

def is_devices_enough(required):
    """
    Check if the available GPUs is enough
    """
    return len(jax.devices()) >= required


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def combine_biases(*masks: Optional[Array]):
    """Combine attention biases.

    Args:
      *masks: set of attention bias arguments to combine, some can be None.

    Returns:
      Combined mask, reduced by summation, returns None if no masks given.
    """
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask
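# Illustrative sketch of combine_biases (shapes/values are hypothetical): biases of the
# same rank are summed, None entries are dropped, and broadcasting handles mismatched dims:
#
#   rel_bias = jnp.zeros((1, 8, 16, 16))       # relative-position bias
#   pad_bias = jnp.full((2, 1, 1, 16), -1e10)  # padding bias
#   combined = combine_biases(rel_bias, pad_bias, None)  # -> shape (2, 8, 16, 16)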


def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
    """
    Takes an input dictionary of parameters keyed by test type "L0", etc.
    Returns a list of pytest parameters to be used in a parameterized test for the current test type
    """
    DEFAULT_TEST_LEVEL = "L0"
    test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
    if test_level not in param_dict:
        raise ValueError("Unsupported test level")
    return values_to_named_params(param_dict[test_level], id_prefix)


def value_to_test_name_str(value):
    """Converts a value to how it should appear in a test name."""
    if isinstance(value, (tuple, list)):
        return "_".join([value_to_test_name_str(v) for v in value])

    dtype_type = type(jnp.float32)
    if isinstance(value, dtype_type):
        return value.dtype

    return str(value)


def value_to_named_param(value, id_prefix: str = ""):
    param_type = type(pytest.param(0))
    if isinstance(value, param_type):
        return value

    x = pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")
    return x


def values_to_named_params(params, id_prefix: str = ""):
    return [value_to_named_param(v, id_prefix=id_prefix) for v in params]


def pytest_parametrize_wrapper(param_name, param_values):
    """
    A wrapper for pytest.mark.parametrize to allow for automatic
    naming of tests based on the parameter values.
    """
    id_prefix = param_name
    if isinstance(param_values, dict):
        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
    elif "," not in param_name:
        param_values = values_to_named_params(param_values, id_prefix=id_prefix)

    # Currently comma separated parameters in one parametrize call aren't supported for automatic naming
    # and will just be passed through with default pytest names
    def decorator(func):
        return pytest.mark.parametrize(param_name, param_values)(func)

    return decorator
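# A minimal usage sketch for the parametrization helpers above. The test function,
# parameter names, and values below are hypothetical:
#
#   @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
#   @pytest_parametrize_wrapper("batch_size", {"L0": [2], "L1": [2, 8]})
#   def test_layer(batch_size, dtype):
#       ...
#
# With the default NVTE_JAX_UNITTEST_LEVEL=L0 only batch_size=2 is generated; setting
# the variable to L1 expands the dict-valued parameter to both batch sizes.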


class DotProductAttention(nn.Module):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.

    Args:
        dropout_rate: dropout rate
        dtype: the data type used to allocate the initial parameters (default: float32).
        float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = True
    dropout_rate: float = 0.0
    dtype: DType = jnp.float32
    float32_logits: bool = False

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        bias: Optional[Array] = None,
        deterministic: bool = False,
    ):
        """
        Args:
            query: queries for calculating attention with shape of `[batch, q_length,
            num_heads, qk_depth_per_head]`.
            key: keys for calculating attention with shape of `[batch, kv_length,
            num_gqa_groups, qk_depth_per_head]`.
            value: values to be used in attention with shape of `[batch, kv_length,
            num_gqa_groups, v_depth_per_head]`.
            bias: bias for the attention weights. This should be broadcastable to the
            shape `[batch, num_heads, q_length, kv_length]` This can be used for
            incorporating causal masks, padding masks, proximity bias, etc.
            deterministic: bool, deterministic or not (to apply dropout)
        Returns:
            Output of shape `[batch, length, num_heads, v_depth_per_head]`.
        """
        input_dtype = query.dtype
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_attn_logits:
            head_dim = query.shape[-1]
            depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
            query = query / depth_scaling

        # Cast logits to float32 so the softmax is computed in float32 for numerical stability.
        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        # `attn_weights`: [batch, num_gqa_groups, group_size, q_length, kv_length]
        h_q, h_kv = query.shape[-2], key.shape[-2]
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
        else:
            attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)

        # reshape back to normal DPA shape for bias/softmax/dropout
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # Apply attention bias: masking, dropout, proximity bias, etc.
        if bias is not None:
            attn_weights = attn_weights + bias.astype(attn_weights.dtype)

        # Normalize the attention weights across `kv_length` dimension.
        attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)

        # Apply attention dropout.
        if not deterministic and self.dropout_rate > 0.0:
            keep_prob = 1.0 - self.dropout_rate
            # T5 broadcasts along the "length" dim, but unclear which one that
            # corresponds to in positional dimensions here, assuming query dim.
            dropout_shape = list(attn_weights.shape)
            dropout_rng = self.make_rng("dropout")
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
        # attn_weights = attn_weights.astype(input_dtype)

        # Take the linear combination of `value`.
        if self.transpose_batch_sequence:
            return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)

        assert (
            attn_weights.dtype == input_dtype
        ), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"

        return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
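def _dot_product_attention_example():
    """Illustrative sketch (not exercised by the tests): run the reference attention
    on random inputs with hypothetical [batch, seq, heads, head_dim] shapes."""
    rng_q, rng_kv, rng_init = jax.random.split(jax.random.PRNGKey(0), 3)
    q = jax.random.normal(rng_q, (2, 16, 8, 64), dtype=jnp.float32)
    k = v = jax.random.normal(rng_kv, (2, 16, 8, 64), dtype=jnp.float32)
    attn = DotProductAttention(transpose_batch_sequence=False)
    # This module creates no params; init_with_output simply runs the forward pass.
    out, _ = attn.init_with_output(rng_init, q, k, v, deterministic=True)
    assert out.shape == q.shape
    return out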


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes and FP8 support.

    Attributes:
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
    dtype: the data type used to allocate the initial parameters (default: float32).
    kernel_init: initializer function for the weight matrix.
    use_bias: whether to add a bias to the output (default: False).
    bias_init: initializer function for the bias vector.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
        inputs: The nd-array to be transformed.

        Returns:
        The transformed input.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
        kernel = nn_partitioning.param_with_axes(
            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
        )

        kernel = jnp.asarray(kernel, input_dtype)
        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
            )
            bias = bias.astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

        y = lax.dot_general(
            inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
        )

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

        assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
        return y
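def _dense_general_example():
    """Illustrative sketch (not exercised by the tests): project the last axis of a
    hypothetical [batch, seq, hidden] tensor to 32 features."""
    x = jnp.ones((2, 16, 64), dtype=jnp.bfloat16)
    layer = DenseGeneral(features=32, kernel_axes=("embed", "mlp"))
    y, _ = layer.init_with_output(jax.random.PRNGKey(0), x)
    assert y.shape == (2, 16, 32) and y.dtype == x.dtype
    return y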


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Attributes:
      intermediate_dim: Shared dimension of hidden layers.
      activations: Type of activations for each layer.  Each element is either
        'linear', a string function name in flax.linen, or a function.
      kernel_init: Kernel function, passed to the dense layers.
      deterministic: Whether the dropout layers should be deterministic.
      intermediate_dropout_rate: Dropout rate used after the intermediate layers.
      dtype: the data type used to allocate the initial parameters (default: float32).
    """

    transpose_batch_sequence: bool
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    kernel_init: Initializer = None
    intermediate_dropout_rate: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    use_bias: bool = False
    dtype: Any = jnp.float32
    fuse_wi: bool = True

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, deterministic: bool = False):
        """Applies Transformer MlpBlock module."""
        # Iterate over specified MLP input activation functions.
        # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.

        activations = []
        if self.fuse_wi:
            dense_name = "wi"
            num_activations = len(self.activations)
            x = DenseGeneral(
                self.intermediate_dim * num_activations,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                kernel_axes=("embed", "mlp"),
                use_bias=self.use_bias,
                bias_axes="mlp",
                name=dense_name,
            )(inputs)
            x = jnp.split(x, num_activations, axis=-1)
            for idx, act_fn in enumerate(self.activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                activations.append(x_i)
        else:
            for idx, act_fn in enumerate(self.activations):
                dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
                x = DenseGeneral(
                    self.intermediate_dim,
                    dtype=self.dtype,
                    kernel_init=self.kernel_init,
                    kernel_axes=("embed", "mlp"),
                    use_bias=self.use_bias,
                    bias_axes="mlp",
                    name=dense_name,
                )(inputs)
                x = _convert_to_activation_function(act_fn)(x)
                activations.append(x)

        # Take elementwise product of above intermediate activations.
        x = functools.reduce(operator.mul, activations)
        # Apply dropout and final dense output projection.
        x = nn.Dropout(
            rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
        )(
            x, deterministic=deterministic
        )  # Broadcast along length.
        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
        output = DenseGeneral(
            inputs.shape[-1],
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("mlp", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            name="wo",
        )(x)
        assert (
            output.dtype == inputs.dtype
        ), f"input.dtype={inputs.dtype}, output.dtype={output.dtype}"
        return output
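# Illustrative note on the block above: with activations=("gelu", "linear") and fuse_wi=True,
# a single "wi" projection produces 2 * intermediate_dim features that are split, gated
# elementwise (gelu(x1) * x2), then projected back by "wo" -- the usual gated-GELU MLP.
# A hypothetical usage sketch:
#
#   mlp = MlpBlock(transpose_batch_sequence=False, intermediate_dim=128,
#                  activations=("gelu", "linear"))
#   y, _ = mlp.init_with_output(jax.random.PRNGKey(0), jnp.ones((2, 16, 64)),
#                               deterministic=True)
#   assert y.shape == (2, 16, 64)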


def apply_rotary_pos_emb_alternate(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(inputs.dtype)
    second_part = second_part.astype(inputs.dtype)
    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)


def apply_rotary_pos_emb_consecutive(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim

    inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    inputs_shifted = jax.lax.select(
        jnp.tile(
            jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
            inputs.shape[:-1] + (1,),
        ),
        inputs_shifted_right,
        inputs_shifted_left,
    )
    fraction = jnp.repeat(fraction, 2)
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction

    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))

    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
    outputs = inputs * cos + inputs_shifted * sin * sign

    return outputs.astype(inputs.dtype)
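def _rotary_pos_emb_example():
    """Illustrative sketch (not exercised by the tests): rotary embedding preserves the
    input shape and dtype, rotating feature pairs by position-dependent angles.
    Shapes below are hypothetical."""
    x = jnp.ones((2, 4, 1, 8), dtype=jnp.bfloat16)  # [batch, seq, heads, head_dim]
    pos = jnp.broadcast_to(jnp.arange(4), (2, 4))  # absolute positions per batch entry
    y = apply_rotary_pos_emb_alternate(x, pos)
    assert y.shape == x.shape and y.dtype == x.dtype
    return y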


dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      num_gqa_groups: number of kv attention heads
      head_dim: dimension of each head.
      dtype: the data type used to allocate the initial parameters (default: float32).
      dropout_rate: dropout rate
      kernel_init: initializer for the kernel of the Dense layers.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    num_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    transpose_batch_sequence: bool = True
    dtype: DType = jnp.float32
    dropout_rate: float = 0.0
    kernel_init: Initializer = None
    float32_logits: bool = False  # computes logits in float32 for stability.
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv: bool = True
    use_bias: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and projects the results to an output vector.

        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for an actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.

        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.

        Returns:
          output of shape `[batch, length, q_features]`.
        """
        q_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_heads * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        kv_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_gqa_groups * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        # NOTE: T5 does not explicitly rescale the attention logits by
        #       1/sqrt(depth_kq)!  This is folded into the initializers of the
        #       linear transformations, which is equivalent under Adafactor
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
        query_init = lambda *args: self.kernel_init(*args) / (
            depth_scaling if self.scaled_query_init else 1.0
        )

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]

        def qkv_init(key, shape, dtype):
            assert shape[-1] % 3 == 0

            q_shape = (shape[0], shape[1] // 3)
            k_shape = (shape[0], shape[1] // 3)
            v_shape = (shape[0], shape[1] // 3)

            q_kernel = query_init(key, q_shape, dtype)
            k_kernel = self.kernel_init(key, k_shape, dtype)
            v_kernel = self.kernel_init(key, v_shape, dtype)

            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
        if self.fuse_qkv:
            if is_qkvpack:
                qkv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_heads * self.head_dim * 3,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_kv)
                query, key, value = jnp.split(
                    qkv_proj,
                    [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1,
                )
            else:
                query = q_projection(kernel_init=query_init, name="query")(inputs_q)

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_gqa_groups * self.head_dim * 2,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
        else:
            query = q_projection(kernel_init=query_init, name="query")(inputs_q)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
        if self.enable_rotary_pos_emb:
            batch_dim = 1 if self.transpose_batch_sequence else 0
            seq_dim = 1 - batch_dim

            q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)

            if self.rotary_pos_emb_group_method == "alternate":
                apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
            else:
                apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive

            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            query = apply_rotary_pos_emb(query, q_position)
            key = apply_rotary_pos_emb(key, k_position)

        query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
        key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if self.transpose_batch_sequence:
            query = nn_partitioning.with_sharding_constraint(
                query, ("length", "batch", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("length", "batch", "heads", "kv")
            )
        else:
            query = nn_partitioning.with_sharding_constraint(
                query, ("batch", "length", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("batch", "length", "heads", "kv")
            )

        if decode:
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of a
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
                # and cache the keys and values step by step.
                # Sanity shape check of cached key against input query.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                # Create a OHE of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                # In order to update the key, value caches with the current key and
                # value, we move the length axis to the back, similar to what we did for
                # the cached ones above.
                # Note these are currently the key and value of a single position, since
                # we feed one position at a time.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)
                # Update key, value caches with our new 1d spatial slices.
                # We implement an efficient scatter into the cache via one-hot
                # broadcast and addition.
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Move the keys and values back to their original shapes.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have already
                # been generated and cached, not the remaining zero elements.
                mask = combine_masks(
                    jnp.logical_not(mask),
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
                        # (1, 1, length) represent (head dim, query length, key length)
                        # query length is 1 because during decoding we deal with one
                        # index.
                        # The same mask is applied to all batch elements and heads.
                        (batch, 1, 1, length),
                    ),
                )

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
                if bias is not None:
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.0).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        # Add provided bias term (e.g. relative position embedding).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        # Apply attention.
        x = DotProductAttention(
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )(query, key, value, bias=attention_bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))

        # Back to the original inputs dimensions.
        out = DenseGeneral(
            features=inputs_q.shape[-1],  # output dim is set to the input dim.
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            dtype=self.dtype,
            name="out",
        )(x)

        assert (
            inputs_q.dtype == inputs_kv.dtype == out.dtype
        ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
        return out


class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data."""

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    bias_init: Initializer = nn.initializers.zeros

    def __post_init__(self):
        if self.scale_init is None:
            if not self.zero_centered_gamma:
                self.scale_init = nn.initializers.ones
            else:
                self.scale_init = nn.initializers.zeros
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies layer normalization on the input."""

        input_dtype = x.dtype
        features = x.shape[-1]

        scale = nn_partitioning.param_with_axes(
            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
        )
        x_ = x.astype(jnp.float32)
        if self.layernorm_type == "layernorm":
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            y = (x_ - mean) * lax.rsqrt(var + self.epsilon)

            bias = nn_partitioning.param_with_axes(
                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
            )
            bias = jnp.asarray(bias, input_dtype)

            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
                z = y * (scale + 1.0) + bias
        else:
            assert self.layernorm_type == "rmsnorm"
            assert not self.zero_centered_gamma
            mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
            y = x_ * lax.rsqrt(mean2 + self.epsilon)
            z = y * scale
        z = z.astype(input_dtype)

        assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
        return z
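def _layernorm_example():
    """Illustrative sketch (not exercised by the tests): apply the RMSNorm variant to a
    hypothetical [batch, seq, hidden] activation."""
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 64), dtype=jnp.float32)
    ln = LayerNorm(layernorm_type="rmsnorm")
    y, _ = ln.init_with_output(jax.random.PRNGKey(1), x)
    assert y.shape == x.shape and y.dtype == x.dtype
    return y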


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Attributes:
      num_buckets: Number of buckets to bucket distances between key and query
        positions into.
      max_distance: Maximum distance before everything is lumped into the last
        distance bucket.
      num_heads: Number of heads in the attention layer. Each head will get a
        different relative position weighting.
      dtype: the data type used to allocate the initial parameters (default: float32).
      embedding_init: initializer for relative embedding table.
    """

    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions.  All relative
        positions >=max_distance  map to the same bucket.  All relative
        positions <=-max_distance map to the same bucket.  This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
          relative_position: an int32 array
          bidirectional: a boolean - whether the attention is bidirectional
          num_buckets: an integer
          max_distance: an integer

        Returns:
          a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret
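    # Illustrative reading of the bucketing above (assuming bidirectional=True,
    # num_buckets=32, max_distance=128): relative positions with |distance| < 8 each get
    # their own bucket, larger past distances are log-spaced into buckets up to 15, and
    # future (positive) relative positions use the mirrored bucket range [16, 32).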

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
          qlen: attention query length.
          klen: attention key length.
          bidirectional: whether to allow positive memory-query relative position
            embeddings.

        Returns:
          output: `(1, num_heads, q_len, k_len)` attention bias
        """
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        relative_attention_bias = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_heads, self.num_buckets),
            jnp.float32,
            axes=("heads", "relpos_buckets"),
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
        # array from rp_bucket and use it to perform the gather-equivalent via a
        # contraction, i.e.:
        # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
        # This is equivalent to relative_attention_bias[:, rp_bucket]
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (num_heads, qlen, klen)
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
            (((1,), (0,)), ((), ())),  # rhs, lhs contracting dims
        )  # no batched dims
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]


def apply_swa_mask(
    attn_mask_type: str,
    original_mask: Array,
    window_size: Tuple[int, int] = (-1, -1),
) -> Array:
    """Apply the sliding window mask to a given mask"""
    _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
    assert _attn_mask_type is not None
    batch = original_mask.shape[0]
    max_seqlen_q = original_mask.shape[-2]
    max_seqlen_kv = original_mask.shape[-1]
    pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
    pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
    swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
    # In swa_mask and original_mask 0 is masked out
    new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
    return new_mask
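def _apply_swa_mask_example():
    """Illustrative sketch (not exercised by the tests): restrict a hypothetical all-ones
    8x8 causal-style mask to a sliding window of 2 tokens on the left and 0 on the right."""
    mask = jnp.ones((1, 1, 8, 8), dtype=jnp.uint8)  # [batch, 1, seqlen_q, seqlen_kv], 1 = keep
    return apply_swa_mask("causal", mask, window_size=(2, 0))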


class EncoderLayer(nn.Module):
    """Transformer encoder layer."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    output_layernorm: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        # The cuDNN backend currently supports SWA only for causal/padding_causal mask types,
        # so follow that convention here.
        encoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            encoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
        else:
            encoder_bias = None

        # Attention block.
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            fuse_qkv=self.fuse_qkv_params,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            use_bias=self.use_bias,
            name="attention",
        )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # MLP block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(y, deterministic=deterministic)

        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
            y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                y, deterministic=deterministic
            )
        y = y + residual

        if self.output_layernorm:
            y = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(y)
        assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
        return y
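# Illustrative usage of the encoder layer above (hypothetical shapes; the default
# transpose_batch_sequence=True expects [seq, batch, hidden] inputs):
#
#   layer = EncoderLayer(num_attention_heads=8, head_dim=64, mlp_dim=256)
#   x = jnp.ones((16, 2, 512))
#   mask = jnp.ones((2, 1, 16, 16), dtype=jnp.uint8)
#   y, _ = layer.init_with_output(jax.random.PRNGKey(0), x, mask, deterministic=True)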


class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs,
        encoded,
        decoder_mask=None,
        encoder_decoder_mask=None,
        deterministic=False,
        decode=False,
        max_decode_length=None,
    ):
        decoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            decoder_mask,
            self.window_size,
        )

        encoder_decoder_mask = apply_swa_mask(
            "padding",
            encoder_decoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            decoder_bias = rel_emb(l, l, False)
        else:
            decoder_bias = None

        # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_self_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # Self-attention block
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="self_attention",
        )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # Encoder-Decoder block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_cross_attention_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y
        y = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="encoder_decoder_attention",
        )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        y = y + residual

        # MLP block.
        residual = y
        z = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(y)
        if self.apply_residual_connection_post_layernorm:
            residual = z
        z = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(z, deterministic=deterministic)
        z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            z, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
        return z
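

# A minimal usage sketch for the decoder layer defined above. The class name
# `DecoderLayer` and its constructor defaults are assumptions (the class body
# appears earlier in this file); shapes follow the default
# transpose_batch_sequence=True convention, i.e. [seqlen, batch, emb_dim].
def _example_decoder_layer_call():
    rng = jax.random.PRNGKey(0)
    seqlen, batch, emb_dim = 16, 2, 32
    inputs = jax.random.normal(rng, (seqlen, batch, emb_dim), dtype=jnp.float32)
    encoded = jax.random.normal(rng, (seqlen, batch, emb_dim), dtype=jnp.float32)
    decoder_mask = make_causal_mask(batch, seqlen)
    encoder_decoder_mask = make_self_mask(batch, seqlen)
    layer = DecoderLayer(num_attention_heads=4, head_dim=8, mlp_dim=64)
    variables = layer.init(
        rng, inputs, encoded, decoder_mask, encoder_decoder_mask, deterministic=True
    )
    return layer.apply(
        variables, inputs, encoded, decoder_mask, encoder_decoder_mask, deterministic=True
    )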


def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate causal mask
    """
    shape = (batch, seqlen)
    idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)

    mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)
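

# Quick sanity-check sketch for make_causal_mask (helpers named "_example_*"
# are illustrative additions, not part of the test suite): for batch=1,
# seqlen=3 the result has shape (1, 1, 3, 3) and a 1 marks a future position:
#   [[0, 1, 1],
#    [0, 0, 1],
#    [0, 0, 0]]
def _example_make_causal_mask():
    mask = make_causal_mask(1, 3)
    assert mask.shape == (1, 1, 3, 3)
    assert bool(mask[0, 0, 0, 1]) and not bool(mask[0, 0, 1, 0])
    return mask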


def make_self_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate attention mask
    """
    shape = (batch, seqlen)
    mask = jnp.ones((*shape, shape[-1]))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)
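

# Companion sketch for make_self_mask (illustrative helper): the result is an
# all-zeros (batch, 1, seqlen, seqlen) tensor, so no position is masked and it
# acts as a trivially "fully visible" attention mask in these tests.
def _example_make_self_mask():
    mask = make_self_mask(2, 4)
    assert mask.shape == (2, 1, 4, 4)
    assert int(mask.sum()) == 0
    return mask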


def assert_allclose(
    actual: Array,
    desired: Array,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
    **kwargs,
) -> None:
    """Check if two tensors are close.

    Args:
      actual: test tensor.
      desired: reference tensor.
      rtol: relative tolerance (default: based on `dtype`).
      atol: absolute tolerance (default: based on `dtype`).
      dtype: data type or data type name (default: inferred from
        `actual`).
      **kwargs: keyword arguments to pass to np.testing.assert_allclose.
    """

    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)
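

# Illustrative use of assert_allclose (hypothetical helper): compare a
# bfloat16-rounded result against its fp32 reference without hand-picking
# tolerances, letting dtype_tols derive them from the bfloat16 dtype.
def _example_assert_allclose():
    ref = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
    out = ref.astype(jnp.bfloat16).astype(jnp.float32)
    assert_allclose(out, ref, dtype=jnp.bfloat16)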


def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
    flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
    flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)

    for (expected_path, expected_value), (actual_path, actual_value) in zip(
        flatten_expected, flatten_actual
    ):
        assert expected_path == actual_path
        key_str = jax.tree_util.keystr(expected_path)
        assert_allclose(
            expected_value,
            actual_value,
            rtol=rtol,
            atol=atol,
            err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
        )
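

# Illustrative use of assert_tree_like_allclose (hypothetical helper) on nested
# pytrees such as flax parameter dicts: both trees must have the same structure
# and each pair of leaves is compared with assert_allclose.
def _example_assert_tree_like_allclose():
    expected = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
    actual = {"dense": {"kernel": jnp.ones((2, 2)) + 1e-7, "bias": jnp.zeros(2)}}
    assert_tree_like_allclose(expected, actual)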


def dtype_tols(
    dtype: Union[DType, TEDType, np.dtype],
    reference_value: float = 1.0,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> Dict[str, float]:
    """Expected numerical tolerance for a data type.

    Args:
      dtype: data type.
      reference_value: reference value (default: 1).
      rtol: override for the relative tolerance estimate.
      atol: override for the absolute tolerance estimate.

    Returns:
      Dictionary with "rtol" and "atol" as keys.
    """

    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Convert to JAX dtype if needed
    if isinstance(dtype, TEDType):
        dtype = {
            TEDType.kByte: jnp.uint8,
            TEDType.kInt32: jnp.int32,
            TEDType.kInt64: jnp.int64,
            TEDType.kFloat32: jnp.float32,
            TEDType.kFloat16: jnp.float16,
            TEDType.kBFloat16: jnp.bfloat16,
            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = jnp.dtype(dtype)

    # Expect bit-wise accuracy for integer dtypes
    if not jnp.issubdtype(dtype, jnp.floating):
        if rtol is None:
            rtol = 0.0
        if atol is None:
            atol = 0.0
        return {"rtol": rtol, "atol": atol}

    # Estimate floating-point error
    finfo = jnp.finfo(dtype)
    eps_relaxed = math.pow(finfo.eps, 2 / 3)
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            reference_value = jnp.array(reference_value, dtype=dtype)
        else:
            reference_value = reference_value.astype(dtype)
        spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
        spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
        ulp = max(spacing_high.item(), spacing_low.item())
    if rtol is None:
        rtol = eps_relaxed
    if atol is None:
        atol = max(ulp, eps_relaxed)
    return {"rtol": rtol, "atol": atol}
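

# Illustrative use of dtype_tols (hypothetical helper): tolerances scale with
# the precision of the dtype, so float16 comparisons get much looser bounds
# than float32 ones.
def _example_dtype_tols():
    fp32_tols = dtype_tols(jnp.float32)
    fp16_tols = dtype_tols(jnp.float16)
    assert fp16_tols["rtol"] > fp32_tols["rtol"]
    return fp32_tols, fp16_tols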


def sync_params_values(dst, src, transformations, sep="/"):
    """
    This function will reconstuct a tree with dst's tree_def/shape and src's value.
    transformations is a map that records the key mappings between dst and src.
    If no dst key found in the transformerations, it will fall back to src key = dst key.
    transformations = {
        dst key map 0: src key map 0,
        dst key map 1: src key map 1,
        ...
        # if dst key = src key, we don't need to add it
    }
    """
    src_values = {}
    for key, value in jax.tree_util.tree_leaves_with_path(src):
        normalized_key = sep.join(x.key for x in key)
        src_values[normalized_key] = value

    flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
    synced_dst_values = []

    for key, value in flatten_dst:
        normalized_key = sep.join(x.key for x in key)
        if normalized_key in transformations:
            corresponding_src_key = transformations[normalized_key]
        else:
            corresponding_src_key = normalized_key
        synced_dst_values.append(src_values[corresponding_src_key])

    synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)

    return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
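

# Illustrative use of sync_params_values (hypothetical helper): copy values
# from a source parameter tree into a destination tree whose keys differ,
# mapping the made-up destination leaf "mlp/wi_kernel" onto the source leaf
# "mlp/kernel".
def _example_sync_params_values():
    src = {"mlp": {"kernel": jnp.ones((4, 8))}}
    dst = {"mlp": {"wi_kernel": jnp.zeros((4, 8))}}
    synced = sync_params_values(dst, src, {"mlp/wi_kernel": "mlp/kernel"})
    assert synced["mlp"]["wi_kernel"].shape == (4, 8)
    return synced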


@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
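    """Print mean/min/max/numel/nonzero-count statistics of `tensor`, labelled
    with `prefix`, plus an optional 10-bin histogram when `hist` is True.
    Output is emitted only if the NVTE_DEBUG_NUMERICS environment variable is
    set to a non-zero value."""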
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [h[0], h[1]]
            fmt = fmt + "\n  {}\n  {}"

        jax.debug.print(fmt, *args)
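

# Debugging usage sketch (hypothetical helper): only prints when the
# NVTE_DEBUG_NUMERICS environment variable is set before importing this module.
def _example_print_debug_tensor_stats():
    x = jnp.arange(12.0).reshape(3, 4)
    print_debug_tensor_stats("example/x", x, hist=True)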