utils.py 60.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
zlsh80826's avatar
zlsh80826 committed
4
"""Utility for the TE layer tests"""
5

6
import functools
7
import math
8
import operator
9
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
10
import os
11

12
import jax
13
14
import jax.numpy as jnp
import numpy as np
15
16
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
17
from flax.linen.attention import combine_masks
18
19
20
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
21
import pytest
22

23
24
25
26
27
from transformer_engine.jax.attention import (
    AttnMaskType,
    canonicalize_attn_mask_type,
    make_swa_mask,
)
28
from transformer_engine.jax.quantize.helper import DType as TEDType
29

30
31
32
33
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
34
35
36
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
37
38
Initializer = Callable[[PRNGKey, Shape, DType], Array]

39
40
41
# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))

42

43
def is_devices_enough(required):
zlsh80826's avatar
zlsh80826 committed
44
45
46
    """
    Check if the available GPUs is enough
    """
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
    return len(jax.devices()) >= required


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
70
    if fn_or_string == "linear":
71
72
73
74
75
76
77
78
79
80
81
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def combine_biases(*masks: Optional[Array]):
    """Combine attention biases.

82
83
    Args:
      *masks: set of attention bias arguments to combine, some can be None.
84

85
86
87
    Returns:
      Combined mask, reduced by summation, returns None if no masks given.
    """
88
89
90
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
91
92
93
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
94
95
96
97
98
99
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask


100
def get_parameters_for_test_level(param_dict: dict):
101
102
    """
    Takes an input dictionary of parameters keyed by test type "L0", etc.
103
    Returns the parameters for the test level specified in the environment variable
104
105
106
107
108
    """
    DEFAULT_TEST_LEVEL = "L0"
    test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
    if test_level not in param_dict:
        raise ValueError("Unsupported test level")
109
    return param_dict[test_level]
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142


def value_to_test_name_str(value):
    """Converts a value to how it should appear in a test name."""
    if isinstance(value, tuple) or isinstance(value, list):
        return "_".join([value_to_test_name_str(v) for v in value])

    dtype_type = type(jnp.float32)
    if isinstance(value, dtype_type):
        return value.dtype

    return str(value)


def value_to_named_param(value, id_prefix: str = ""):
    param_type = type(pytest.param(0))
    if isinstance(value, param_type):
        return value

    x = pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")
    return x


def values_to_named_params(params, id_prefix: str = ""):
    return [value_to_named_param(v, id_prefix=id_prefix) for v in params]


def pytest_parametrize_wrapper(param_name, param_values):
    """
    A wrapper for pytest.mark.parametrize to allow for automatic
    naming of tests based on the parameter values.
    """
    if isinstance(param_values, dict):
143
144
145
146
147
148
149
150
151
152
        # If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
        # unwrap the selected level before proceeding.
        param_values = get_parameters_for_test_level(param_values)

    if "," not in param_name:
        # Multi-parameterize annotations are not supported in this wrapper
        # and are just a passthrough to default pytest.mark.parametrize.
        # E.g. @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2)))
        # will be passed through to pytest.mark.parametrize as-is without pytest.param ids.
        param_values = values_to_named_params(param_values, id_prefix=param_name)
153
154
155
156
157
158
159

    def decorator(func):
        return pytest.mark.parametrize(param_name, param_values)(func)

    return decorator


160
161
162
class DotProductAttention(nn.Module):
    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = True
163
    dropout_rate: float = 0.0
164
165
    dtype: DType = jnp.float32
    float32_logits: bool = False
166
167
    """Computes dot-product attention given query, key, and value.

168
169
170
    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.
171

172
173
    Args:
        dropout_rate: dropout rate
174
        dtype: the data type used to allocate the initial parameters (default: float32).
175
176
177
        float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """
178

179
    @nn.compact
180
181
182
183
184
185
186
187
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        bias: Optional[Array] = None,
        deterministic: bool = False,
    ):
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
        """
        Args:
            query: queries for calculating attention with shape of `[batch, q_length,
            num_heads, qk_depth_per_head]`.
            key: keys for calculating attention with shape of `[batch, kv_length,
            num_gqa_groups, qk_depth_per_head]`.
            value: values to be used in attention with shape of `[batch, kv_length,
            num_gqa_groups, v_depth_per_head]`.
            bias: bias for the attention weights. This should be broadcastable to the
            shape `[batch, num_heads, q_length, kv_length]` This can be used for
            incorporating causal masks, padding masks, proximity bias, etc.
            dropout_rng: JAX PRNGKey: to be used for dropout
            deterministic: bool, deterministic or not (to apply dropout)
        Returns:
            Output of shape `[batch, length, num_heads, v_depth_per_head]`.
        """
204
        input_dtype = query.dtype
205
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
206
        batch_dim = 1 if self.transpose_batch_sequence else 0
207
208
209
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
210
        sequence_dim = 0 if self.transpose_batch_sequence else 1
211
212
213
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
214
215
216

        if self.scale_attn_logits:
            head_dim = query.shape[-1]
217
            depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
218
219
220
221
222
223
224
225
226
227
228
229
230
231
            query = query / depth_scaling

        # Casting logits and softmax computation for float32 for model stability.
        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        # `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
        h_q, h_kv = query.shape[-2], key.shape[-2]
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
232
            attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
233
        else:
234
            attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
235
236
237
238
239
240
241
242
243
244
245

        # reshape back to normal DPA shape for bias/softmax/dropout
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # Apply attention bias: masking, dropout, proximity bias, etc.
        if bias is not None:
            attn_weights = attn_weights + bias.astype(attn_weights.dtype)

        # Normalize the attention weights across `kv_length` dimension.
246
        attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
247
248

        # Apply attention dropout.
249
        if not deterministic and self.dropout_rate > 0.0:
250
251
252
253
            keep_prob = 1.0 - self.dropout_rate
            # T5 broadcasts along the "length" dim, but unclear which one that
            # corresponds to in positional dimensions here, assuming query dim.
            dropout_shape = list(attn_weights.shape)
254
            dropout_rng = self.make_rng("dropout")
255
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
256
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
257
258
259
            attn_weights = attn_weights * multiplier

        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
260
        # attn_weights = attn_weights.astype(input_dtype)
261
262
263

        # Take the linear combination of `value`.
        if self.transpose_batch_sequence:
264
            return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
265

266
267
268
269
        assert (
            attn_weights.dtype == input_dtype
        ), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"

270
        return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
271
272
273
274
275


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes and FP8 support.

276
277
278
    Attributes:
    features: tuple with numbers of output features.
    axis: tuple with axes to apply the transformation on.
279
    dtype: the data type used to allocate the initial parameters (default: float32).
280
281
282
    kernel_init: initializer function for the weight matrix.
    use_bias: whether to add a bias to the output (default: False).
    bias_init: initializer function for the bias vector.
283
    """
284

285
286
287
288
289
290
291
292
293
294
295
    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
296
297
298
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
299
300
301
302
303
304
305
306
307
308
309
310
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
        inputs: The nd-array to be transformed.

        Returns:
        The transformed input.
        """
311
        input_dtype = inputs.dtype
312
313
314
315
316
317
318
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
319
        kernel = nn_partitioning.param_with_axes(
320
            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
321
        )
322

323
        kernel = jnp.asarray(kernel, input_dtype)
324
325
326
        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
327
            bias = nn_partitioning.param_with_axes(
328
                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
329
            )
330
            bias = bias.astype(input_dtype)
331
332
333
334
335
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

336
337
338
        y = lax.dot_general(
            inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
        )
339
340
341

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
342
343

        assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
344
345
346
347
348
349
        return y


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

350
351
352
353
354
355
356
    Attributes:
      intermediate_dim: Shared dimension of hidden layers.
      activations: Type of activations for each layer.  Each element is either
        'linear', a string function name in flax.linen, or a function.
      kernel_init: Kernel function, passed to the dense layers.
      deterministic: Whether the dropout layers should be deterministic.
      intermediate_dropout_rate: Dropout rate used after the intermediate layers.
357
      dtype: the data type used to allocate the initial parameters (default: float32).
358
359
    """

360
361
    transpose_batch_sequence: bool
    intermediate_dim: int = 2048
362
    activations: Sequence[Union[str, Callable]] = ("relu",)
363
364
    kernel_init: Initializer = None
    intermediate_dropout_rate: float = 0.1
365
366
    intermediate_dropout_dims: Sequence[int] = ()
    use_bias: bool = False
367
    dtype: Any = jnp.float32
368
    fuse_wi: bool = True
369
370
371

    def __post_init__(self):
        if self.kernel_init is None:
372
373
374
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
375
376
377
378
379
380
381
382
383
384
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, deterministic: bool = False):
        """Applies Transformer MlpBlock module."""
        # Iterate over specified MLP input activation functions.
        # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.

        activations = []
        if self.fuse_wi:
385
            dense_name = "wi"
386
            num_activations = len(self.activations)
387
388
389
390
391
392
393
394
395
            x = DenseGeneral(
                self.intermediate_dim * num_activations,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                kernel_axes=("embed", "mlp"),
                use_bias=self.use_bias,
                bias_axes="mlp",
                name=dense_name,
            )(inputs)
396
397
398
399
400
401
            x = jnp.split(x, num_activations, axis=-1)
            for idx, act_fn in enumerate(self.activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                activations.append(x_i)
        else:
            for idx, act_fn in enumerate(self.activations):
402
403
404
405
406
407
408
409
410
411
                dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
                x = DenseGeneral(
                    self.intermediate_dim,
                    dtype=self.dtype,
                    kernel_init=self.kernel_init,
                    kernel_axes=("embed", "mlp"),
                    use_bias=self.use_bias,
                    bias_axes="mlp",
                    name=dense_name,
                )(inputs)
412
413
414
415
416
417
                x = _convert_to_activation_function(act_fn)(x)
                activations.append(x)

        # Take elementwise product of above intermediate activations.
        x = functools.reduce(operator.mul, activations)
        # Apply dropout and final dense output projection.
418
419
420
421
422
        x = nn.Dropout(
            rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
        )(
            x, deterministic=deterministic
        )  # Broadcast along length.
423

424
        if self.transpose_batch_sequence:
425
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
426
        else:
427
428
429
430
431
432
433
434
435
436
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
        output = DenseGeneral(
            inputs.shape[-1],
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("mlp", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            name="wo",
        )(x)
437

438
439
440
        assert (
            output.dtype == inputs.dtype
        ), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
441
442
443
        return output


444
def apply_rotary_pos_emb_alternate(
445
446
447
448
449
450
451
452
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
453
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
454
455
456
457
458
459
460
461
462
463
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(inputs.dtype)
    second_part = second_part.astype(inputs.dtype)
464
    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)
465
466


467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
def apply_rotary_pos_emb_consecutive(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim

    inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    inputs_shifted = jax.lax.select(
        jnp.tile(
            jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
            inputs.shape[:-1] + (1,),
        ),
        inputs_shifted_right,
        inputs_shifted_left,
    )
    fraction = jnp.repeat(fraction, 2)
488
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
489
490
491
492
493
494
495
496
497

    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))

    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
    outputs = inputs * cos + inputs_shifted * sin * sign

498
    return outputs.astype(inputs.dtype)
499
500


501
502
503
504
505
506
507
508
509
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
zlsh80826's avatar
zlsh80826 committed
510
      num_gqa_groups: number of kv attention heads
511
      head_dim: dimension of each head.
512
      dtype: the data type used to allocate the initial parameters (default: float32).
513
514
515
516
      dropout_rate: dropout rate
      kernel_init: initializer for the kernel of the Dense layers.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
517
    """
518

zlsh80826's avatar
zlsh80826 committed
519
520
521
522
    num_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    transpose_batch_sequence: bool = True
523
    dtype: DType = jnp.float32
524
    dropout_rate: float = 0.0
525
    kernel_init: Initializer = None
526
    float32_logits: bool = False  # computes logits in float32 for stability.
527
528
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
529
    enable_rotary_pos_emb: bool = False
530
    rotary_pos_emb_group_method: str = "consecutive"
531
    fuse_qkv: bool = True
532
    use_bias: bool = False
533
534
535

    def __post_init__(self):
        if self.kernel_init is None:
536
537
538
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
zlsh80826's avatar
zlsh80826 committed
539
540
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
541
542
543
        super().__post_init__()

    @nn.compact
544
545
546
547
548
549
550
551
552
553
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
554
555
        """Applies multi-head dot product attention on the input data.

556
557
        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and project the results to an output vector.
558

559
560
561
562
563
564
        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for an actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.
565

566
567
568
569
        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.
570

571
572
573
574
575
576
577
        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.
578

579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
        Returns:
          output of shape `[batch, length, q_features]`.
        """
        q_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_heads * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        kv_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_gqa_groups * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )
601
602
603
604
605

        # NOTE: T5 does not explicitly rescale the attention logits by
        #       1/sqrt(depth_kq)!  This is folded into the initializers of the
        #       linear transformations, which is equivalent under Adafactor
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
606
607
608
        query_init = lambda *args: self.kernel_init(*args) / (
            depth_scaling if self.scaled_query_init else 1.0
        )
609
610
611
612
613
614
615
616
617
618
619
620

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]

        def qkv_init(key, shape, dtype):
            assert shape[-1] % 3 == 0

            q_shape = (shape[0], shape[1] // 3)
            k_shape = (shape[0], shape[1] // 3)
            v_shape = (shape[0], shape[1] // 3)

            q_kernel = query_init(key, q_shape, dtype)
zlsh80826's avatar
zlsh80826 committed
621
622
            k_kernel = self.kernel_init(key, k_shape, dtype)
            v_kernel = self.kernel_init(key, v_shape, dtype)
623
624
625

            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

626
627
628
        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa
zlsh80826's avatar
zlsh80826 committed
629

630
        if self.fuse_qkv:
zlsh80826's avatar
zlsh80826 committed
631
            if is_qkvpack:
632

633
634
635
636
637
638
639
640
641
642
                qkv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_heads * self.head_dim * 3,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_kv)
643

644
                query, key, value = jnp.split(
645
646
647
648
                    qkv_proj,
                    [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1,
                )
649

650
            else:
651
652
653
654
655
656
657
658
659
660
661
662
                query = q_projection(kernel_init=query_init, name="query")(inputs_q)

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_gqa_groups * self.head_dim * 2,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
zlsh80826's avatar
zlsh80826 committed
663
                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
664
        else:
665
666
667
            query = q_projection(kernel_init=query_init, name="query")(inputs_q)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
668

669
670
671
672
        if self.enable_rotary_pos_emb:
            batch_dim = 1 if self.transpose_batch_sequence else 0
            seq_dim = 1 - batch_dim

673
674
            q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
675

676
            if self.rotary_pos_emb_group_method == "alternate":
677
678
679
680
                apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
            else:
                apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive

681
682
683
684
            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            query = apply_rotary_pos_emb(query, q_position)
            key = apply_rotary_pos_emb(key, k_position)
685

zlsh80826's avatar
zlsh80826 committed
686
687
688
        query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
        key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
689
690

        if self.transpose_batch_sequence:
691
692
693
694
695
696
697
            query = nn_partitioning.with_sharding_constraint(
                query, ("length", "batch", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("length", "batch", "heads", "kv")
            )
698
        else:
699
700
701
702
703
704
705
            query = nn_partitioning.with_sharding_constraint(
                query, ("batch", "length", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("batch", "length", "heads", "kv")
            )
706
707
708

        if decode:
            # Detect if we're initializing by absence of existing cache data.
709
            is_initialized = self.has_variable("cache", "cached_key")
710
711
712
713
714
            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of a
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
zlsh80826's avatar
zlsh80826 committed
715
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
716
717
718
719
720
721
722
723
724
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
725
726
727
728
729
730
731
732
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
                # and cache the keys and values step by step.
                # Sanity shape check of cached key against input query.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
733
734
735
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762

                # Create a OHE of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                # In order to update the key, value caches with the current key and
                # value, we move the length axis to the back, similar to what we did for
                # the cached ones above.
                # Note these are currently the key and value of a single position, since
                # we feed one position at a time.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)
                # Update key, value caches with our new 1d spatial slices.
                # We implement an efficient scatter into the cache via one-hot
                # broadcast and addition.
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Move the keys and values back to their original shapes.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have already
                # been generated and cached, not the remaining zero elements.
                mask = combine_masks(
763
                    jnp.logical_not(mask),
764
765
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
766
767
768
769
770
771
772
                        # (1, 1, length) represent (head dim, query length, key length)
                        # query length is 1 because during decoding we deal with one
                        # index.
                        # The same mask is applied to all batch elements and heads.
                        (batch, 1, 1, length),
                    ),
                )
773
774
775
776
777
778
779

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
                if bias is not None:
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
780
781
782
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )
783
784
785
786

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
787

788
789
790
791
792
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.0).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype),
            )
793
794
795
796
797
798
799
800
        else:
            attention_bias = None

        # Add provided bias term (e.g. relative position embedding).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        # Apply attention.
801
802
803
804
805
806
807
        x = DotProductAttention(
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )(query, key, value, bias=attention_bias, deterministic=deterministic)
808
809
810
811

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
812
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
813
        else:
814
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
815
816

        # Back to the original inputs dimensions.
817

818
        out = DenseGeneral(
819
            features=inputs_q.shape[-1],  # output dim is set to the input dim.
820
821
            axis=-1,
            kernel_init=self.kernel_init,
822
            kernel_axes=("joined_kv", "embed"),
823
            use_bias=self.use_bias,
824
            bias_axes="embed",
825
            dtype=self.dtype,
826
827
            name="out",
        )(x)
828

829
830
831
        assert (
            inputs_q.dtype == inputs_kv.dtype == out.dtype
        ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
832
833
834
835
836
        return out


class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data."""
837

838
839
    epsilon: float = 1e-6
    dtype: Any = jnp.float32
840
    layernorm_type: str = "layernorm"
841
842
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
843
844
    bias_init: Initializer = nn.initializers.zeros

845
846
847
848
849
850
851
852
    def __post_init__(self):
        if self.scale_init is None:
            if not self.zero_centered_gamma:
                self.scale_init = nn.initializers.ones
            else:
                self.scale_init = nn.initializers.zeros
        super().__post_init__()

853
854
855
856
    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies layer normalization on the input."""

857
        input_dtype = x.dtype
858
859
        features = x.shape[-1]

860
        scale = nn_partitioning.param_with_axes(
861
            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
862
        )
863
        x_ = x.astype(jnp.float32)
864
        if self.layernorm_type == "layernorm":
865
866
867
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
868

869
            bias = nn_partitioning.param_with_axes(
870
                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
871
            )
872
            bias = jnp.asarray(bias, input_dtype)
873

874
875
876
            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
877
                z = y * (scale + 1.0) + bias
878
        else:
879
            assert self.layernorm_type == "rmsnorm"
880
            assert not self.zero_centered_gamma
881
882
            mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
            y = x_ * lax.rsqrt(mean2 + self.epsilon)
883
            z = y * scale
884
        z = z.astype(input_dtype)
885

886
887
        assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
        return z
888
889
890
891
892


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

893
894
895
896
897
898
899
    Attributes:
      num_buckets: Number of buckets to bucket distances between key and query
        positions into.
      max_distance: Maximum distance before everything is lumped into the last
        distance bucket.
      num_heads: Number of heads in the attention layer. Each head will get a
        different relative position weighting.
900
      dtype: the data type used to allocate the initial parameters (default: float32).
901
902
903
      embedding_init: initializer for relative embedding table.
    """

904
905
906
907
908
909
910
    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
911
912
913
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
914
915
        """Translate relative position to a bucket number for relative attention.

916
917
918
919
920
921
922
923
924
925
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions.  All relative
        positions >=max_distance  map to the same bucket.  All relative
        positions <=-max_distance map to the same bucket.  This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.
926

927
928
929
930
931
        Args:
          relative_position: an int32 array
          bidirectional: a boolean - whether the attention is bidirectional
          num_buckets: an integer
          max_distance: an integer
932

933
934
935
936
        Returns:
          a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
937
938
939
940
941
942
943
944
945
946
947
948
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
949
950
951
952
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
953
954
955
956
957
958
959
960
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

961
962
963
964
965
        Args:
          qlen: attention query length.
          klen: attention key length.
          bidirectional: whether to allow positive memory-query relative position
            embeddings.
966

967
968
969
        Returns:
          output: `(1, len, q_len, k_len)` attention bias
        """
970
971
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
972
973
974
975
976
977
978
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
979
        relative_attention_bias = nn_partitioning.param_with_axes(
980
981
982
            "rel_embedding",
            self.embedding_init,
            (self.num_heads, self.num_buckets),
983
            jnp.float32,
984
985
            axes=("heads", "relpos_buckets"),
        )
986
987
988
989
990
991
992
993
994
995
996
997
998

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
        # array from rp_bucket and use it to perform the gather-equivalent via a
        # contraction, i.e.:
        # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
        # This is equivalent to relative_attention_bias[:, rp_bucket]
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (qlen, klen, num_heads)
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
999
1000
            (((1,), (0,)), ((), ())),  # rhs, lhs contracting dims
        )  # no batched dims
1001
1002
1003
1004
1005
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]


1006
1007
1008
1009
1010
1011
def apply_swa_mask(
    attn_mask_type: str,
    original_mask: Array,
    window_size: Tuple[int, int] = (-1, -1),
) -> Array:
    """Apply the sliding window mask to a given mask"""
1012
    _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
1013
    assert _attn_mask_type is not None
1014
    batch = original_mask.shape[0]
1015
1016
    max_seqlen_q = original_mask.shape[-2]
    max_seqlen_kv = original_mask.shape[-1]
1017
1018
1019
    pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
    pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
    swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
1020
    # In swa_mask and original_mask 0 is masked out
1021
    new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
1022
1023
1024
    return new_mask


1025
1026
class EncoderLayer(nn.Module):
    """Transformer encoder layer."""
1027

1028
    enable_relative_embedding: bool = True
1029
    relative_embedding: nn.Module = None
zlsh80826's avatar
zlsh80826 committed
1030
1031
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
1032
    head_dim: int = 64
1033
1034
1035
1036
1037
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1038
1039
1040
1041
1042
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
1043
    mlp_activations: Sequence[str] = ("relu",)
1044
    use_bias: bool = False
1045
1046
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
1047
    layernorm_type: str = "layernorm"
1048
    layernorm_epsilon: float = 1e-6
1049
    zero_centered_gamma: bool = False
1050
1051
    output_layernorm: bool = False
    drop_path: float = 0.0
1052
    enable_rotary_pos_emb: bool = False
1053
    rotary_pos_emb_group_method: str = "consecutive"
1054
    fuse_qkv_params: bool = True
1055
1056
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
1057
1058
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)
1059

zlsh80826's avatar
zlsh80826 committed
1060
1061
1062
1063
1064
    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

1065
1066
    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
1067
1068
1069
1070
1071
1072
1073
        # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
        encoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            encoder_mask,
            self.window_size,
        )

1074
1075
1076
1077
        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

1078
1079
        if self.enable_relative_embedding:
            if self.relative_embedding is None:
1080
1081
1082
1083
1084
1085
1086
1087
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
1088
1089
1090
            else:
                rel_emb = self.relative_embedding
            encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
1091
        else:
1092
            encoder_bias = None
1093
1094
1095
1096
1097
1098

        # Attention block.
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
1099
1100
1101
1102
1103
1104
1105
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_attention_layer_norm",
            )(inputs)
1106
1107
1108
1109
1110
1111
1112

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            fuse_qkv=self.fuse_qkv_params,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            use_bias=self.use_bias,
            name="attention",
        )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
1132
1133
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1134
1135
1136
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
1137
1138
1139
1140
        x = x + residual

        # MLP block.
        residual = x
1141
1142
1143
1144
1145
1146
1147
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(x)
1148
1149
1150
1151
1152
1153
1154
1155
1156

        if self.apply_residual_connection_post_layernorm:
            residual = y

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
1157
1158
1159
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
1160
1161
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
1162
            name="mlp",
1163
        )(y, deterministic=deterministic)
1164

1165
1166
1167
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
1168

1169
1170
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
1171
1172
1173
            y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                y, deterministic=deterministic
            )
1174
1175
1176
        y = y + residual

        if self.output_layernorm:
1177
1178
1179
1180
1181
1182
1183
            y = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(y)
1184

1185
        assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
1186
1187
1188
1189
1190
        return y


class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder."""
1191

1192
    enable_relative_embedding: bool = True
1193
    relative_embedding: nn.Module = None
zlsh80826's avatar
zlsh80826 committed
1194
1195
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
1196
    head_dim: int = 64
1197
1198
1199
1200
1201
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
1202
1203
1204
1205
1206
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
1207
    mlp_activations: Sequence[str] = ("relu",)
1208
    use_bias: bool = False
1209
1210
1211
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
1212
    layernorm_type: str = "layernorm"
1213
    layernorm_epsilon: float = 1e-6
1214
    zero_centered_gamma: bool = False
1215
    drop_path: float = 0.0
1216
    enable_rotary_pos_emb: bool = False
1217
    rotary_pos_emb_group_method: str = "consecutive"
1218
    fuse_qkv_params: bool = True
1219
1220
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
1221
1222
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)
1223

zlsh80826's avatar
zlsh80826 committed
1224
1225
1226
1227
1228
    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

1229
    @nn.compact
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
    def __call__(
        self,
        inputs,
        encoded,
        decoder_mask=None,
        encoder_decoder_mask=None,
        deterministic=False,
        decode=False,
        max_decode_length=None,
    ):
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
        decoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            decoder_mask,
            self.window_size,
        )

        encoder_decoder_mask = apply_swa_mask(
            "padding",
            encoder_decoder_mask,
            self.window_size,
        )

1252
1253
1254
        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim
1255
1256
1257
1258

        if self.enable_relative_embedding:
            l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
            if self.relative_embedding is None:
1259
1260
1261
1262
1263
1264
1265
1266
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
1267
1268
1269
            else:
                rel_emb = self.relative_embedding
            decoder_bias = rel_emb(l, l, False)
1270
        else:
1271
            decoder_bias = None
1272
1273
1274
1275
1276
1277

        # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
1278
1279
1280
1281
1282
1283
1284
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_self_attention_layer_norm",
            )(inputs)
1285
1286
1287
1288
1289
1290
1291

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # Self-attention block
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="self_attention",
        )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
1311
1312
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
1313
1314
1315
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
1316
1317
1318
1319
        x = x + residual

        # Encoder-Decoder block.
        residual = x
1320
1321
1322
1323
1324
1325
1326
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_cross_attention_layer_norm",
        )(x)
1327
1328
1329

        if self.apply_residual_connection_post_layernorm:
            residual = y
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
        y = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="encoder_decoder_attention",
        )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
1349
1350
1351
1352
        y = y + residual

        # MLP block.
        residual = y
1353
1354
1355
1356
1357
1358
1359
        z = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(y)
1360
1361
1362
1363
1364
1365
        if self.apply_residual_connection_post_layernorm:
            residual = z
        z = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
1366
1367
1368
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
1369
1370
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
1371
            name="mlp",
1372
        )(z, deterministic=deterministic)
1373
1374
1375
        z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            z, deterministic=deterministic
        )
1376
1377
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
1378
1379
1380
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
1381
1382
1383
        z = z + residual

        if self.output_layernorm:
1384
1385
1386
1387
1388
1389
1390
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)
1391

1392
        assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
1393
1394
1395
        return z


1396
def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
zlsh80826's avatar
zlsh80826 committed
1397
1398
1399
    """
    Generate causal mask
    """
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
    shape = (batch, seqlen)
    idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)

    mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)


def make_self_mask(batch, seqlen, dtype=jnp.uint8):
zlsh80826's avatar
zlsh80826 committed
1410
1411
1412
    """
    Generate attention mask
    """
1413
1414
1415
1416
1417
1418
1419
    shape = (batch, seqlen)
    mask = jnp.ones((*shape, shape[-1]))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)


1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
def assert_allclose(
    actual: Array,
    desired: Array,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
    **kwargs,
) -> None:
    """Check if two tensors are close.

    Args:
      actual: test tensor.
      desired: reference tensor.
      dtype: data type or data type name (default: inferred from
        `actual`).
      rtol: relative tolerance (default: based on `dtype`).
      atol: absolute tolerance (default: based on `dtype`).
      **kwargs: keyword arguments to pass to np.testing.assert_allclose.
    """

    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
zlsh80826's avatar
zlsh80826 committed
1448
    tols = {}
1449
1450
1451
1452
1453
1454
1455
1456
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
1457
1458
1459
1460
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)
1461
1462
1463
1464
1465

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)


1466
1467
1468
1469
def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
    flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
    flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)

1470
1471
1472
    for (expected_path, expected_value), (actual_path, actual_value) in zip(
        flatten_expected, flatten_actual
    ):
1473
1474
        assert expected_path == actual_path
        key_str = jax.tree_util.keystr(expected_path)
1475
1476
1477
1478
1479
1480
1481
        assert_allclose(
            expected_value,
            actual_value,
            rtol=rtol,
            atol=atol,
            err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
        )
1482
1483


1484
1485
1486
def dtype_tols(
    dtype: Union[DType, TEDType, np.dtype],
    reference_value: float = 1.0,
1487
1488
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
1489
1490
1491
1492
1493
1494
) -> Dict[str, float]:
    """Expected numerical tolerance for a data type.

    Args:
      dtype: data type.
      reference_value: reference value (default: 1).
1495
1496
      rtol: override for relative tolerance estimate
      atol: override for absolute tolerance estimate
1497
1498
1499
1500
1501
1502

    Returns:
      Dictionary with "rtol" and "atol" as keys

    """

1503
1504
1505
1506
    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
    # Convert to JAX dtype if needed
    if isinstance(dtype, TEDType):
        dtype = {
            TEDType.kByte: jnp.uint8,
            TEDType.kInt32: jnp.int32,
            TEDType.kInt64: jnp.int64,
            TEDType.kFloat32: jnp.float32,
            TEDType.kFloat16: jnp.float16,
            TEDType.kBFloat16: jnp.bfloat16,
            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = jnp.dtype(dtype)

    # Expect bit-wise accuracy for integer dtypes
    if not jnp.issubdtype(dtype, jnp.floating):
1524
1525
1526
1527
1528
        if rtol is None:
            rtol = 0.0
        if atol is None:
            atol = 0.0
        return {"rtol": rtol, "atol": atol}
1529
1530
1531

    # Estimate floating-point error
    finfo = jnp.finfo(dtype)
1532
    eps_relaxed = math.pow(finfo.eps, 2 / 3)
1533
1534
1535
1536
1537
1538
1539
1540
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            reference_value = jnp.array(reference_value, dtype=dtype)
        else:
            reference_value = reference_value.astype(dtype)
        spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
        spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
        ulp = max(spacing_high.item(), spacing_low.item())
1541
1542
1543
1544
1545
    if rtol is None:
        rtol = eps_relaxed
    if atol is None:
        atol = max(ulp, eps_relaxed)
    return {"rtol": rtol, "atol": atol}
1546
1547


1548
def sync_params_values(dst, src, transformations, sep="/"):
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
    """
    This function will reconstuct a tree with dst's tree_def/shape and src's value.
    transformations is a map that records the key mappings between dst and src.
    If no dst key found in the transformerations, it will fall back to src key = dst key.
    transformations = {
        dst key map 0: src key map 0,
        dst key map 1: src key map 1,
        ...
        # if dst key = src key, we don't need to add it
    }
    """
    src_values = {}
    for key, value in jax.tree_util.tree_leaves_with_path(src):
        normalized_key = sep.join(x.key for x in key)
        src_values[normalized_key] = value

    flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
    synced_dst_values = []

    for key, value in flatten_dst:
        normalized_key = sep.join(x.key for x in key)
        if normalized_key in transformations:
            corresponding_src_key = transformations[normalized_key]
        else:
            corresponding_src_key = normalized_key
        synced_dst_values.append(src_values[corresponding_src_key])

    synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)

    return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598


@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [h[0], h[1]]
            fmt = fmt + "\n  {}\n  {}"

        jax.debug.print(fmt, *args)