# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""

import os
import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
from contextlib import contextmanager

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
import pytest

from transformer_engine.jax.attention import (
    canonicalize_attn_mask_type,
    make_swa_mask,
)
from transformer_engine.jax.quantize.helper import DType as TEDType

PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]

# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))


def is_devices_enough(required):
    """Check whether the number of available GPUs meets the requirement."""
    return len(jax.devices()) >= required


def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
    # Generate broadcast dims for drop_path.
    drop_path_shape = list(range(0, len(shape)))
    drop_path_shape.pop(batch_dim)
    return drop_path_shape


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    return (x,)


def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


def combine_biases(*masks: Optional[Array]):
    """Combine attention biases.

    Args:
      *masks: set of attention bias arguments to combine, some can be None.

    Returns:
      Combined bias, reduced by summation; returns None if no biases are given.
    """
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    assert all(
        map(lambda x: x.ndim == masks[0].ndim, masks)
    ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
    mask, *other_masks = masks
    for other_mask in other_masks:
        mask = mask + other_mask
    return mask
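

# Tiny illustrative sketch (hypothetical shapes, not a real test configuration):
# biases combine by a broadcasting sum, and None entries are dropped.
def _example_combine_biases():
    rel_bias = jnp.zeros((1, 4, 8, 8))
    pad_bias = jnp.full((2, 1, 8, 8), -1e10)
    combined = combine_biases(rel_bias, None, pad_bias)
    assert combined.shape == (2, 4, 8, 8)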


def get_parameters_for_test_level(param_dict: dict):
    """
    Takes an input dictionary of parameters keyed by test level ("L0", "L1", etc.)
    and returns the parameters for the level specified by the
    NVTE_JAX_UNITTEST_LEVEL environment variable (default: "L0").
    """
    DEFAULT_TEST_LEVEL = "L0"
    test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
    if test_level not in param_dict:
        raise ValueError(f"Unsupported test level: {test_level}")
    return param_dict[test_level]
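

# Illustrative sketch (hypothetical parameter dict, not a real test config):
# with NVTE_JAX_UNITTEST_LEVEL unset, the default "L0" subset is selected.
def _example_test_level_selection():
    params = {
        "L0": [8, 16],  # quick smoke-test sizes
        "L1": [8, 16, 32, 64],  # extended sizes for longer runs
    }
    return get_parameters_for_test_level(params)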


def value_to_test_name_str(value):
    """Converts a value to how it should appear in a test name."""
    if isinstance(value, (tuple, list)):
        return "_".join([value_to_test_name_str(v) for v in value])

    dtype_type = type(jnp.float32)
    if isinstance(value, dtype_type):
        return value.dtype

    return str(value)


def value_to_named_param(value, id_prefix: str = ""):
    """Wraps a value in pytest.param with a generated id, unless it already is one."""
    param_type = type(pytest.param(0))
    if isinstance(value, param_type):
        return value

    return pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")


def values_to_named_params(params, id_prefix: str = ""):
    return [value_to_named_param(v, id_prefix=id_prefix) for v in params]


def pytest_parametrize_wrapper(param_name, param_values):
    """
    A wrapper for pytest.mark.parametrize to allow for automatic
    naming of tests based on the parameter values.
    """
    if isinstance(param_values, dict):
        # If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
        # unwrap the selected level before proceeding.
        param_values = get_parameters_for_test_level(param_values)

    if "," not in param_name:
        # Only single-parameter names get generated pytest.param ids here.
        # Multi-parameter annotations, e.g.
        # @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2))),
        # pass through to pytest.mark.parametrize as-is, without generated ids.
        param_values = values_to_named_params(param_values, id_prefix=param_name)

    def decorator(func):
        return pytest.mark.parametrize(param_name, param_values)(func)

    return decorator
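

# Illustrative usage sketch (underscore-prefixed so pytest does not collect it):
# the wrapper names each case after its value, so the generated ids read
# "dtype_float32" / "dtype_bfloat16" rather than "dtype0" / "dtype1".
@pytest_parametrize_wrapper("dtype", [jnp.float32, jnp.bfloat16])
def _example_parametrized_case(dtype):
    assert jnp.zeros((2,), dtype=dtype).dtype == dtype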


class DotProductAttention(nn.Module):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.

    Attributes:
        transpose_batch_sequence: whether inputs use a [sequence, batch, ...]
        layout instead of [batch, sequence, ...].
        scale_attn_logits: whether to scale the query by 1/sqrt(head_dim).
        dropout_rate: dropout rate
        dtype: the data type used to allocate the initial parameters (default: float32).
        float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = True
    dropout_rate: float = 0.0
    dtype: DType = jnp.float32
    float32_logits: bool = False

    @nn.compact
    def __call__(
        self,
        query: Array,
        key: Array,
        value: Array,
        bias: Optional[Array] = None,
        deterministic: bool = False,
    ):
        """
        Args:
            query: queries for calculating attention with shape of `[batch, q_length,
            num_heads, qk_depth_per_head]`.
            key: keys for calculating attention with shape of `[batch, kv_length,
            num_gqa_groups, qk_depth_per_head]`.
            value: values to be used in attention with shape of `[batch, kv_length,
            num_gqa_groups, v_depth_per_head]`.
            bias: bias for the attention weights. This should be broadcastable to the
            shape `[batch, num_heads, q_length, kv_length]` This can be used for
            incorporating causal masks, padding masks, proximity bias, etc.
            deterministic: bool, deterministic or not (to apply dropout); the
            dropout PRNG key is drawn via self.make_rng("dropout").
        Returns:
            Output of shape `[batch, length, num_heads, v_depth_per_head]`.
        """
        input_dtype = query.dtype
        assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
        batch_dim = 1 if self.transpose_batch_sequence else 0
        assert (
            query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
        ), "q, k, v batch dims must match."
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
        assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
        assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."

        if self.scale_attn_logits:
            head_dim = query.shape[-1]
            depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
            query = query / depth_scaling

        # Cast logits and compute the softmax in float32 for model stability.
        if self.float32_logits:
            query = query.astype(jnp.float32)
            key = key.astype(jnp.float32)

        # `attn_weights`: [batch, num_kv_heads, group_size, q_length, kv_length]
        h_q, h_kv = query.shape[-2], key.shape[-2]
        assert (h_q % h_kv == 0) and (h_q >= h_kv)
        group_size = h_q // h_kv
        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))

        if self.transpose_batch_sequence:
            attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
        else:
            attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)

        # reshape back to normal DPA shape for bias/softmax/dropout
        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
        attn_weights_without_groups_shape = (b, h * g, q, k)
        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)

        # Apply attention bias: masking, dropout, proximity bias, etc.
        if bias is not None:
            attn_weights = attn_weights + bias.astype(attn_weights.dtype)

        # Normalize the attention weights across `kv_length` dimension.
        attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)

        # Apply attention dropout.
        if not deterministic and self.dropout_rate > 0.0:
            keep_prob = 1.0 - self.dropout_rate
            # T5 broadcasts along the "length" dim, but unclear which one that
            # corresponds to in positional dimensions here, assuming query dim.
            dropout_shape = list(attn_weights.shape)
            dropout_rng = self.make_rng("dropout")
            keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
            attn_weights = attn_weights * multiplier

        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)

        assert (
            attn_weights.dtype == input_dtype
        ), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"

        # Take the linear combination of `value`.
        if self.transpose_batch_sequence:
            return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)

        return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
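

# Minimal usage sketch (shapes are illustrative): grouped dot-product attention
# on random inputs in batch-first layout, with 4 query heads over 2 KV heads.
def _example_dot_product_attention():
    rng = jax_random.PRNGKey(0)
    q = jax_random.normal(rng, (2, 8, 4, 16))  # [batch, q_len, num_heads, head_dim]
    kv = jax_random.normal(rng, (2, 8, 2, 16))  # [batch, kv_len, num_gqa_groups, head_dim]
    attn = DotProductAttention(transpose_batch_sequence=False)
    out = attn.apply({}, q, kv, kv, deterministic=True)
    assert out.shape == q.shape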


class DenseGeneral(nn.Module):
    """A linear transformation with flexible axes and FP8 support.

275
276
277
    Attributes:
      features: tuple with numbers of output features.
      axis: tuple with axes to apply the transformation on.
      dtype: the data type used to allocate the initial parameters (default: float32).
      kernel_init: initializer function for the weight matrix.
      use_bias: whether to add a bias to the output (default: False).
      bias_init: initializer function for the bias vector.
    """

    features: Union[Iterable[int], int]
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    kernel_init: Initializer = None
    kernel_axes: Tuple[str, ...] = ()
    use_bias: bool = False
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        Args:
        inputs: The nd-array to be transformed.

        Returns:
        The transformed input.
        """
        input_dtype = inputs.dtype
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        axis = _normalize_axes(axis, inputs.ndim)

        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
            kernel_param_shape,
            self.dtype,
        )

        kernel = jnp.asarray(kernel, input_dtype)
        kernel = jnp.reshape(kernel, kernel_shape)

        if self.use_bias:
            bias = self.param(
                "bias",
                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
                self.features,
                self.dtype,
            )
            bias = bias.astype(input_dtype)
        else:
            bias = None

        contract_ind = tuple(range(0, len(axis)))

        y = lax.dot_general(
            inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
        )

        if bias is not None:
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

        assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
        return y
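

# Minimal usage sketch: project the last axis of an input with DenseGeneral
# (kernel_axes chosen to match the 2D kernel parameter; sizes illustrative).
def _example_dense_general():
    layer = DenseGeneral(features=32, axis=-1, kernel_axes=("embed", "mlp"))
    x = jnp.ones((4, 16))
    variables = layer.init(jax_random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    assert y.shape == (4, 32)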


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Attributes:
      intermediate_dim: Shared dimension of hidden layers.
      activations: Type of activations for each layer.  Each element is either
        'linear', a string function name in flax.linen, or a function.
      kernel_init: Kernel function, passed to the dense layers.
      deterministic: Whether the dropout layers should be deterministic.
      intermediate_dropout_rate: Dropout rate used after the intermediate layers.
      dtype: the data type used to allocate the initial parameters (default: float32).
    """

    transpose_batch_sequence: bool
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    kernel_init: Initializer = None
    intermediate_dropout_rate: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    use_bias: bool = False
    dtype: Any = jnp.float32
    fuse_wi: bool = True

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "truncated_normal", dtype=self.dtype
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, deterministic: bool = False):
        """Applies Transformer MlpBlock module."""
        # Iterate over specified MLP input activation functions.
        # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.

        activations = []
        if self.fuse_wi:
            dense_name = "wi"
            num_activations = len(self.activations)
            x = DenseGeneral(
                self.intermediate_dim * num_activations,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                kernel_axes=("embed", "mlp"),
                use_bias=self.use_bias,
                bias_axes="mlp",
                name=dense_name,
            )(inputs)
            x = jnp.split(x, num_activations, axis=-1)
            for idx, act_fn in enumerate(self.activations):
                x_i = _convert_to_activation_function(act_fn)(x[idx])
                activations.append(x_i)
        else:
            for idx, act_fn in enumerate(self.activations):
                dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
                x = DenseGeneral(
                    self.intermediate_dim,
                    dtype=self.dtype,
                    kernel_init=self.kernel_init,
                    kernel_axes=("embed", "mlp"),
                    use_bias=self.use_bias,
                    bias_axes="mlp",
                    name=dense_name,
                )(inputs)
                x = _convert_to_activation_function(act_fn)(x)
                activations.append(x)

        # Take elementwise product of above intermediate activations.
        x = functools.reduce(operator.mul, activations)
        # Apply dropout and final dense output projection.
        x = nn.Dropout(
            rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
        )(
            x, deterministic=deterministic
        )  # Broadcast along length.

        if self.transpose_batch_sequence:
            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
        else:
            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
        output = DenseGeneral(
            inputs.shape[-1],
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("mlp", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            name="wo",
        )(x)

        assert (
            output.dtype == inputs.dtype
        ), f"input.dtype={inputs.dtype}, output.dtype={output.dtype}"
        return output
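

# Minimal usage sketch: a gated-GELU MLP block in batch-first layout
# (intermediate size and shapes are illustrative).
def _example_mlp_block():
    mlp = MlpBlock(
        transpose_batch_sequence=False,
        intermediate_dim=64,
        activations=("gelu", "linear"),
    )
    x = jnp.ones((2, 8, 32))
    variables = mlp.init(jax_random.PRNGKey(0), x, deterministic=True)
    y = mlp.apply(variables, x, deterministic=True)
    assert y.shape == x.shape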


def apply_rotary_pos_emb_alternate(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    first_half, second_half = jnp.split(inputs, 2, axis=-1)
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    first_part = first_part.astype(inputs.dtype)
    second_part = second_part.astype(inputs.dtype)
    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)


def apply_rotary_pos_emb_consecutive(
    inputs: jnp.ndarray,
    position: jnp.ndarray,
    min_timescale: int = 1,
    max_timescale: int = 10000,
):
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim

    inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
    inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
    inputs_shifted = jax.lax.select(
        jnp.tile(
            jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
            inputs.shape[:-1] + (1,),
        ),
        inputs_shifted_right,
        inputs_shifted_left,
    )
    fraction = jnp.repeat(fraction, 2)
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction

    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))

    sinusoid_inp = position / timescale
    sin = jnp.sin(sinusoid_inp)
    cos = jnp.cos(sinusoid_inp)
    sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
    outputs = inputs * cos + inputs_shifted * sin * sign

    return outputs.astype(inputs.dtype)
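

# Illustrative numeric check for both rotary embedding layouts: at position 0,
# sin is 0 and cos is 1 for every timescale, so the input passes through unchanged.
def _example_rotary_identity_at_position_zero():
    x = jax_random.normal(jax_random.PRNGKey(0), (1, 1, 2, 8))
    position = jnp.zeros((1, 1), dtype=jnp.int32)
    for rope in (apply_rotary_pos_emb_alternate, apply_rotary_pos_emb_consecutive):
        assert jnp.allclose(rope(x, position), x, atol=1e-6)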


dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))


class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention.

    Attributes:
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      num_gqa_groups: number of kv attention heads
      head_dim: dimension of each head.
      dtype: the data type used to allocate the initial parameters (default: float32).
      dropout_rate: dropout rate
      kernel_init: initializer for the kernel of the Dense layers.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    num_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    transpose_batch_sequence: bool = True
    dtype: DType = jnp.float32
    dropout_rate: float = 0.0
    kernel_init: Initializer = None
    float32_logits: bool = False  # computes logits in float32 for stability.
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv: bool = True
    use_bias: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
                1.0, "fan_in", "normal", dtype=self.dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention, and projects the results to an output vector.

        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for an actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.

        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.

        Returns:
          output of shape `[batch, length, q_features]`.
        """
        q_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_heads * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        kv_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_gqa_groups * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        # NOTE: T5 does not explicitly rescale the attention logits by
        #       1/sqrt(depth_kq)!  This is folded into the initializers of the
        #       linear transformations, which is equivalent under Adafactor
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
        query_init = lambda *args: self.kernel_init(*args) / (
            depth_scaling if self.scaled_query_init else 1.0
        )

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]

        def qkv_init(key, shape, dtype):
            assert shape[-1] % 3 == 0

            q_shape = (shape[0], shape[1] // 3)
            k_shape = (shape[0], shape[1] // 3)
            v_shape = (shape[0], shape[1] // 3)

            q_kernel = query_init(key, q_shape, dtype)
            k_kernel = self.kernel_init(key, k_shape, dtype)
            v_kernel = self.kernel_init(key, v_shape, dtype)

            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa

        if self.fuse_qkv:
            if is_qkvpack:

                qkv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_heads * self.head_dim * 3,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_kv)

                query, key, value = jnp.split(
                    qkv_proj,
                    [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1,
                )

            else:
                query = q_projection(kernel_init=query_init, name="query")(inputs_q)

                kv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_gqa_groups * self.head_dim * 2,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
        else:
            query = q_projection(kernel_init=query_init, name="query")(inputs_q)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)

        if self.enable_rotary_pos_emb:
            batch_dim = 1 if self.transpose_batch_sequence else 0
            seq_dim = 1 - batch_dim

            q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(key.shape[seq_dim]), axis=batch_dim)

            if self.rotary_pos_emb_group_method == "alternate":
                apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
            else:
                apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive

            query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
            key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
            query = apply_rotary_pos_emb(query, q_position)
            key = apply_rotary_pos_emb(key, k_position)

        query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
        key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if self.transpose_batch_sequence:
            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
        else:
            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

        if decode:
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")
            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of a
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
                # and cache the keys and values step by step.
                # Sanity shape check of cached key against input query.
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                # Create a OHE of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
                # In order to update the key, value caches with the current key and
                # value, we move the length axis to the back, similar to what we did for
                # the cached ones above.
                # Note these are currently the key and value of a single position, since
                # we feed one position at a time.
                one_token_key = jnp.moveaxis(key, -3, -1)
                one_token_value = jnp.moveaxis(value, -3, -1)
                # Update key, value caches with our new 1d spatial slices.
                # We implement an efficient scatter into the cache via one-hot
                # broadcast and addition.
                key = cached_key.value + one_token_key * one_hot_indices
                value = cached_value.value + one_token_value * one_hot_indices
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Move the keys and values back to their original shapes.
                key = jnp.moveaxis(key, -1, -3)
                value = jnp.moveaxis(value, -1, -3)

                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have already
                # been generated and cached, not the remaining zero elements.
                mask = combine_masks(
                    jnp.logical_not(mask),
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
                        # (1, 1, length) represent (head dim, query length, key length)
                        # query length is 1 because during decoding we deal with one
                        # index.
                        # The same mask is applied to all batch elements and heads.
                        (batch, 1, 1, length),
                    ),
                )

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
                if bias is not None:
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.0).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        # Add provided bias term (e.g. relative position embedding).
        if bias is not None:
            attention_bias = combine_biases(attention_bias, bias)

        # Apply attention.
        x = DotProductAttention(
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )(query, key, value, bias=attention_bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
        else:
            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

        # Back to the original inputs dimensions.
        out = DenseGeneral(
            features=inputs_q.shape[-1],  # output dim is set to the input dim.
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            dtype=self.dtype,
            name="out",
        )(x)

        assert (
            inputs_q.dtype == inputs_kv.dtype == out.dtype
        ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
        return out
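

# Minimal self-attention usage sketch in batch-first layout (sizes illustrative):
# fused QKV projection with as many KV heads as query heads.
def _example_multi_head_attention():
    mha = MultiHeadAttention(num_heads=4, head_dim=16, transpose_batch_sequence=False)
    x = jnp.ones((2, 8, 64))
    variables = mha.init(jax_random.PRNGKey(0), x, x, deterministic=True)
    y = mha.apply(variables, x, x, deterministic=True)
    assert y.shape == x.shape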


class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data."""

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    bias_init: Initializer = nn.initializers.zeros

    def __post_init__(self):
        if self.scale_init is None:
            if not self.zero_centered_gamma:
                self.scale_init = nn.initializers.ones
            else:
                self.scale_init = nn.initializers.zeros
        super().__post_init__()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies layer normalization on the input."""

        input_dtype = x.dtype
        features = x.shape[-1]

        scale = self.param(
            "scale",
            nn.with_logical_partitioning(self.scale_init, ("embed",)),
            (features,),
            self.dtype,
        )
        x_ = x.astype(jnp.float32)
        if self.layernorm_type == "layernorm":
            mean = jnp.mean(x_, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
            y = (x_ - mean) * lax.rsqrt(var + self.epsilon)

            bias = self.param(
                "ln_bias",
                nn.with_logical_partitioning(self.bias_init, ("embed",)),
                (features,),
                self.dtype,
            )
            bias = jnp.asarray(bias, input_dtype)

            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
                z = y * (scale + 1.0) + bias
        else:
            assert self.layernorm_type == "rmsnorm"
            assert not self.zero_centered_gamma
            mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
            y = x_ * lax.rsqrt(mean2 + self.epsilon)
            z = y * scale
        z = z.astype(input_dtype)

        assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
        return z
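

# Minimal sketch contrasting the two supported normalization types; both
# preserve the input shape and dtype.
def _example_layernorm_variants():
    x = jax_random.normal(jax_random.PRNGKey(0), (2, 16))
    for ln_type in ("layernorm", "rmsnorm"):
        ln = LayerNorm(layernorm_type=ln_type)
        variables = ln.init(jax_random.PRNGKey(1), x)
        y = ln.apply(variables, x)
        assert y.shape == x.shape and y.dtype == x.dtype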


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Attributes:
      num_buckets: Number of buckets to bucket distances between key and query
        positions into.
      max_distance: Maximum distance before everything is lumped into the last
        distance bucket.
      num_heads: Number of heads in the attention layer. Each head will get a
        different relative position weighting.
      dtype: the data type used to allocate the initial parameters (default: float32).
      embedding_init: initializer for relative embedding table.
    """

    num_buckets: int
    max_distance: int
    num_heads: int
    dtype: Any
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position.  If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions.  All relative
        positions >=max_distance  map to the same bucket.  All relative
        positions <=-max_distance map to the same bucket.  This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
          relative_position: an int32 array
          bidirectional: a boolean - whether the attention is bidirectional
          num_buckets: an integer
          max_distance: an integer

        Returns:
          a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).astype(np.int32) * num_buckets
            n = np.abs(n)
        else:
            n = np.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret

    @nn.compact
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
          qlen: attention query length.
          klen: attention key length.
          bidirectional: whether to allow positive memory-query relative position
            embeddings.

        Returns:
          output: `(1, num_heads, q_len, k_len)` attention bias
        """
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        relative_attention_bias = self.param(
            "rel_embedding",
            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
            (self.num_heads, self.num_buckets),
            jnp.float32,
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
        # array from rp_bucket and use it to perform the gather-equivalent via a
        # contraction, i.e.:
        # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
        # This is equivalent to relative_attention_bias[:, rp_bucket]
        bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
        rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (qlen, klen, num_heads)
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
            (((1,), (0,)), ((), ())),  # lhs dim 1 contracts with rhs dim 0
        )  # no batched dims
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]
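

# Minimal sketch: a T5-style relative position bias table queried for an
# 8x8 attention map over 4 heads (sizes illustrative).
def _example_relative_position_bias():
    rpb = RelativePositionBiases(
        num_buckets=32, max_distance=128, num_heads=4, dtype=jnp.float32
    )
    variables = rpb.init(jax_random.PRNGKey(0), 8, 8)
    bias = rpb.apply(variables, 8, 8, bidirectional=True)
    assert bias.shape == (1, 4, 8, 8)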


def apply_swa_mask(
    attn_mask_type: str,
    original_mask: Array,
    window_size: Tuple[int, int] = (-1, -1),
) -> Array:
    """Apply the sliding window mask to a given mask"""
    _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
    assert _attn_mask_type is not None
    batch = original_mask.shape[0]
    max_seqlen_q = original_mask.shape[-2]
    max_seqlen_kv = original_mask.shape[-1]
    pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
    pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
    swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
    # In swa_mask and original_mask 0 is masked out
    new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
    return new_mask
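

# Minimal sketch (illustrative values): apply a causal sliding window of two
# past tokens to an all-ones mask (1 = attend, 0 = masked out); entries that
# are already masked out stay masked.
def _example_apply_swa_mask():
    mask = jnp.ones((1, 1, 4, 4), dtype=jnp.uint8)
    swa = apply_swa_mask("causal", mask, window_size=(2, 0))
    assert swa.shape == mask.shape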


class EncoderLayer(nn.Module):
    """Transformer encoder layer."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    output_layernorm: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        # The cuDNN backend currently supports SWA only for causal/padding_causal
        # mask types, so we follow the same convention here.
        encoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            encoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
        else:
            encoder_bias = None

        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            fuse_qkv=self.fuse_qkv_params,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            use_bias=self.use_bias,
            name="attention",
        )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # MLP block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        y = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(y, deterministic=deterministic)

        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )

        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
            y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                y, deterministic=deterministic
            )
        y = y + residual

        if self.output_layernorm:
            y = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(y)

        assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
        return y
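

# Minimal end-to-end sketch (sizes illustrative): one encoder layer over a
# batch-first input with an all-ones padding mask (1 = attend).
def _example_encoder_layer():
    layer = EncoderLayer(
        transpose_batch_sequence=False,
        num_attention_heads=4,
        head_dim=16,
        mlp_dim=64,
    )
    x = jnp.ones((2, 8, 64))
    mask = jnp.ones((2, 1, 8, 8), dtype=jnp.uint8)
    variables = layer.init(jax_random.PRNGKey(0), x, mask, deterministic=True)
    y = layer.apply(variables, x, mask, deterministic=True)
    assert y.shape == x.shape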


class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    hidden_dropout: float = 0.1
    hidden_dropout_dims: Sequence[int] = ()
    attention_dropout: float = 0.1
    intermediate_dropout: float = 0.1
    intermediate_dropout_dims: Sequence[int] = ()
    transpose_batch_sequence: bool = True
    float32_attention_logits: bool = False
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
    self_attn_mask_type: str = "no_mask"
    window_size: Tuple[int, int] = (-1, -1)

    def __post_init__(self):
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs,
        encoded,
        decoder_mask=None,
        encoder_decoder_mask=None,
        deterministic=False,
        decode=False,
        max_decode_length=None,
    ):
        decoder_mask = apply_swa_mask(
            self.self_attn_mask_type,
            decoder_mask,
            self.window_size,
        )

        encoder_decoder_mask = apply_swa_mask(
            "padding",
            encoder_decoder_mask,
            self.window_size,
        )

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            decoder_bias = rel_emb(l, l, False)
        else:
            decoder_bias = None

        # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
        residual = inputs

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_self_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
        else:
            x = inputs

        # Self-attention block
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="self_attention",
        )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # Encoder-Decoder block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_cross_attention_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y
        y = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="encoder_decoder_attention",
        )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        y = y + residual

        # MLP block.
        residual = y
        z = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(y)
        if self.apply_residual_connection_post_layernorm:
            residual = z
        z = MlpBlock(
            transpose_batch_sequence=self.transpose_batch_sequence,
            intermediate_dim=self.mlp_dim,
            activations=self.mlp_activations,
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_dropout_dims=self.intermediate_dropout_dims,
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(z, deterministic=deterministic)
        z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            z, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
        return z


def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate causal mask
    """
    shape = (batch, seqlen)
    idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)

    mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)

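
# Hedged illustration (the helper below is ours, not part of the test suite):
# for batch=1 and seqlen=3 the causal mask has shape (1, 1, 3, 3), with 1 on
# future positions and 0 elsewhere.
def _example_make_causal_mask():
    mask = make_causal_mask(1, 3)
    assert mask.shape == (1, 1, 3, 3)
    return mask
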

def make_self_mask(batch, seqlen, dtype=jnp.uint8):
    """
    Generate attention mask
    """
    shape = (batch, seqlen)
    mask = jnp.ones((*shape, shape[-1]))
    mask = jnp.expand_dims(mask, axis=-3)
    mask = 1 - mask
    return mask.astype(dtype)

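
# Hedged illustration (helper name is ours): the self mask is all zeros, so
# every position may attend to every other position.
def _example_make_self_mask():
    mask = make_self_mask(1, 3)
    assert mask.shape == (1, 1, 3, 3) and not mask.any()
    return mask
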

def assert_allclose(
    actual: Array,
    desired: Array,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
    **kwargs,
) -> None:
    """Check if two tensors are close.

    Args:
      actual: test tensor.
      desired: reference tensor.
      rtol: relative tolerance (default: based on `dtype`).
      atol: absolute tolerance (default: based on `dtype`).
      dtype: data type or data type name (default: inferred from
        `actual`).
      **kwargs: keyword arguments to pass to np.testing.assert_allclose.
    """

    # Infer data type if needed
    if dtype is None:
        if isinstance(actual, float):
            dtype = "float32"
        else:
            dtype = actual.dtype

    # Determine tolerances
    tols = {}
    if rtol is None or atol is None:
        tols = dtype_tols(dtype)
    if rtol is not None:
        tols["rtol"] = rtol
    if atol is not None:
        tols["atol"] = atol

    # Cast tensors to fp32
    if not isinstance(actual, float):
        actual = actual.astype(jnp.float32)
    if not isinstance(desired, float):
        desired = desired.astype(jnp.float32)

    # Check if tensors are close
    np.testing.assert_allclose(actual, desired, **tols, **kwargs)

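
# Hedged usage sketch (helper name is ours): tolerances are derived from the
# test tensor's dtype unless overridden, so a bfloat16 comparison tolerates
# much larger error than an exact one.
def _example_assert_allclose():
    x = jnp.ones((4,), dtype=jnp.bfloat16)
    assert_allclose(x + 1e-3, x)  # passes: within bfloat16 tolerances
    assert_allclose(x, x, rtol=0.0, atol=0.0)  # exact comparison
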

def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
    flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
    flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)

    for (expected_path, expected_value), (actual_path, actual_value) in zip(
        flatten_expected, flatten_actual
    ):
        assert expected_path == actual_path
        key_str = jax.tree_util.keystr(expected_path)
        assert_allclose(
            expected_value,
            actual_value,
            rtol=rtol,
            atol=atol,
            err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
        )

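
# Hedged usage sketch (helper name is ours): compares two pytrees leaf by
# leaf, asserting that their paths match and their values are close.
def _example_assert_tree_like_allclose():
    expected = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    actual = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    assert_tree_like_allclose(expected, actual)
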

def dtype_tols(
    dtype: Union[DType, TEDType, np.dtype],
    reference_value: float = 1.0,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
) -> Dict[str, float]:
    """Expected numerical tolerance for a data type.

    Args:
      dtype: data type.
      reference_value: reference value (default: 1).
      rtol: override for the relative tolerance estimate.
      atol: override for the absolute tolerance estimate.

    Returns:
      Dictionary with "rtol" and "atol" as keys.
    """

    # Return immediately if tolerances are fully specified
    if rtol is not None and atol is not None:
        return {"rtol": rtol, "atol": atol}

    # Convert to JAX dtype if needed
    if isinstance(dtype, TEDType):
        dtype = {
            TEDType.kByte: jnp.uint8,
            TEDType.kInt32: jnp.int32,
            TEDType.kInt64: jnp.int64,
            TEDType.kFloat32: jnp.float32,
            TEDType.kFloat16: jnp.float16,
            TEDType.kBFloat16: jnp.bfloat16,
            TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = DType(dtype)

    # Expect bit-wise accuracy for integer dtypes
    if not jnp.issubdtype(dtype, jnp.floating):
        if rtol is None:
            rtol = 0.0
        if atol is None:
            atol = 0.0
        return {"rtol": rtol, "atol": atol}

    # Estimate floating-point error
    finfo = jnp.finfo(dtype)
    eps_relaxed = math.pow(finfo.eps, 2 / 3)
    with jax.default_device(jax.devices("cpu")[0]):
        if isinstance(reference_value, (float, int)):
            reference_value = jnp.array(reference_value, dtype=dtype)
        else:
            reference_value = reference_value.astype(dtype)
        spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
        spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
        ulp = max(spacing_high.item(), spacing_low.item())
    if rtol is None:
        rtol = eps_relaxed
    if atol is None:
        atol = max(ulp, eps_relaxed)
    return {"rtol": rtol, "atol": atol}

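
# Hedged usage sketch (helper name is ours): the returned dictionary can be
# splatted straight into np.testing.assert_allclose.
def _example_dtype_tols():
    tols = dtype_tols(jnp.bfloat16)
    np.testing.assert_allclose(1.0 + 1e-3, 1.0, **tols)
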

def sync_params_values(dst, src, transformations, sep="/"):
    """
    This function will reconstuct a tree with dst's tree_def/shape and src's value.
    transformations is a map that records the key mappings between dst and src.
    If no dst key found in the transformerations, it will fall back to src key = dst key.
    transformations = {
        dst key map 0: src key map 0,
        dst key map 1: src key map 1,
        ...
        # if dst key = src key, we don't need to add it
    }
    """
    src_values = {}
    for key, value in jax.tree_util.tree_leaves_with_path(src):
        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
        src_values[normalized_key] = value

    flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
    synced_dst_values = []

    for key, value in flatten_dst:
        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
        if normalized_key in transformations:
            corresponding_src_key = transformations[normalized_key]
        else:
            corresponding_src_key = normalized_key
        synced_dst_values.append(src_values[corresponding_src_key])

    synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)

    return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
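

# Hedged usage sketch (the key names below are made up): copy values from a
# reference parameter tree into a destination tree whose keys differ.
def _example_sync_params_values():
    dst = {"params": {"mlp": {"wi_kernel": jnp.zeros((4, 8))}}}
    src = {"params": {"mlp": {"wi": {"kernel": jnp.ones((4, 8))}}}}
    transformations = {"params/mlp/wi_kernel": "params/mlp/wi/kernel"}
    return sync_params_values(dst, src, transformations)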


@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
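    """Print mean/min/max/element count/nonzero count of `tensor` (and a
    10-bin histogram when `hist` is set) if NVTE_DEBUG_NUMERICS is enabled."""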
    if NVTE_DEBUG_NUMERICS:
        args = [
            jnp.mean(tensor),
            jnp.min(tensor),
            jnp.max(tensor),
            jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1,
            jnp.count_nonzero(tensor),
        ]
        fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"

        if hist:
            h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
            args += [h[0], h[1]]
            fmt = fmt + "\n  {}\n  {}"

        jax.debug.print(fmt, *args)
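

# Hedged usage sketch (helper name is ours): NVTE_DEBUG_NUMERICS=1 must be set
# in the environment before this module is imported for anything to print.
def _example_print_debug_tensor_stats():
    x = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
    print_debug_tensor_stats("x:", x)
    print_debug_tensor_stats("x:", x, hist=True)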


@contextmanager
def use_jax_gemm(enabled=False):
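    """Context manager that, when `enabled` is True, removes TE's GemmPrimitive
    from the NVTE_JAX_CUSTOM_CALLS_RE allowlist so that GEMMs fall back to
    native JAX kernels; the previous filter is restored on exit."""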
    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)

    try:
        if enabled:
            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
        yield

    finally:
        if enabled:
            if orig_custom_calls_filter is None:
                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
            else:
                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
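

# Hedged usage sketch (helper name is ours): TE-JAX modules traced or executed
# inside the context fall back to native JAX GEMM kernels.
def _example_use_jax_gemm():
    with use_jax_gemm(enabled=True):
        a = jnp.ones((4, 8), dtype=jnp.bfloat16)
        b = jnp.ones((8, 2), dtype=jnp.bfloat16)
        return jnp.dot(a, b)  # plain JAX op; TE layers behave the same way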