# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
import os
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    fused_attn,
    run_length_fill,
    make_swa_mask,
    SequenceDescriptor,
    CPStrategy,
    ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
    get_device_compute_capability,
)

from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats

# Get determinism
_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for the CUDA uninitialization error
    """
    # Calling custom calls before JAX is initialized may cause a CUDA uninitialization error
    _ = jnp.zeros(0)
    yield


@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    softmax_offset: Optional[ArrayLike],
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    softmax_type: AttnSoftmaxType,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
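    # Grouped-query attention: h_q query heads share h_kv key/value heads. As a
    # hypothetical example, h_q=8 and h_kv=2 gives num_groups=4, so queries are
    # reshaped to (b, s_q, h_kv, num_groups, d) and each KV head serves its own
    # group of 4 query heads.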
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    match softmax_type:
        case AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_out = jax.nn.softmax(logits).astype(dtype)
        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
            # Append a zero logit, apply standard softmax, then remove last column
            zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
            logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # Append learnable offset logit, apply standard softmax, then remove last column
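            # i.e. exp(x_i) / (sum_j exp(x_j) + exp(offset_h)), with one learnable offset per head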
            learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
            learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
            logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case _:
            raise NotImplementedError(f"Unknown {softmax_type=}")

    if not deterministic and dropout_rate > 0.0:
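        # Inverted dropout on the attention weights: drop entries with probability
        # dropout_rate and rescale the kept ones by 1 / keep_prob.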
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context


@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the segment_ids shape is applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask


@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, with 0 used for padding.
      Each segment is expected to start without padding.
    - segment_pos marks the token position within its segment.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    if attn_mask_type.is_bottom_right():
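        # Bottom-right causal alignment: a query token may attend a KV token when the
        # number of tokens remaining in its query segment is <= the number remaining in
        # the KV segment (assuming run_length_fill returns each token's segment length,
        # this anchors the causal diagonal at the bottom-right corner of the segment).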
        run_length_out_q = run_length_fill(segment_ids_q)
        run_length_out_kv = run_length_fill(segment_ids_kv)
        bottom_right_causal_mask = make_attention_mask(
            run_length_out_q - segment_pos_q,
            run_length_out_kv - segment_pos_kv,
            jnp.less_equal,
        )
        inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
    elif attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = (
        make_swa_mask(
            segment_pos_q,
            segment_pos_kv,
            window_size,
            dtype=jnp.bool,
            segment_ids_q=segment_ids_q,
            segment_ids_kv=segment_ids_kv,
        )
        if attn_mask_type.is_bottom_right()
        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
    )
    inv_mask = combine_masks(inv_mask, inv_swa_mask)
    mask = jnp.logical_not(inv_mask)
    return mask


@jax.jit
def get_seqlens_and_offsets(segment_ids):
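    """
    Convert per-token segment_ids into per-segment sequence lengths and start
    offsets, padded with -1. For example, a single batch row [1, 1, 2, 2, 2, 0]
    should yield seqlens beginning [2, 3, -1, ...] and offsets beginning [0, 2, -1, ...].
    """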
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        softmax_offset,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        softmax_type=kwargs["softmax_type"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    softmax_offset,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(
        qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
    ).astype(query.dtype)


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in fused attention.

    The letters map to the bias dims (batch, heads, s_q, s_kv); "1" denotes a
    broadcast dim, e.g. _1HSS corresponds to (1, num_heads, max_seqlen_q, max_seqlen_kv).
    """

    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"


class SeqDescFormat(Enum):
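    """
    Formats used by these tests to describe valid sequences to the fused attention
    call: a dense attention mask, per-sequence lengths/offsets, or per-token segment IDs.
    """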
    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim_qk: int
    head_dim_v: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat
    stripe_size: int | None = None
    num_segments_per_seq: int | None = None

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

    def __post_init__(self):
        # Reset defaults for num_segments_per_seq if not explicitly passed
        if self.num_segments_per_seq is None:
            if self.qkv_layout.is_thd():
                self.num_segments_per_seq = 2
            else:
                self.num_segments_per_seq = 1

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if self.qkv_layout.is_thd():
            if 90400 <= get_cudnn_version() < 90500:
                return self.num_segments_per_seq
            else:
                # +1 for testing runtime_segments < max_segments
                return self.num_segments_per_seq + 1
        else:
            return 1

    def _check_configs(self):
        # TODO(rewang): probably add this to is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.attn_mask_type.is_bottom_right():
            if self.max_seqlen_q > self.max_seqlen_kv:
                pytest.skip(
                    "BRCM requires a cross-attention pattern, i.e. max_seqlen_kv >= max_seqlen_q"
                )
            if self.attn_bias_type is not AttnBiasType.NO_BIAS:
                pytest.skip("cuDNN does not support pre- or post-scale bias for BRCM")
            if self.dropout_prob != 0.0:
                pytest.skip("cuDNN does not support non-zero dropout for BRCM")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        if get_device_compute_capability(0) >= 100 and self.is_training:
            if FusedAttnHelper.is_non_deterministic_allowed() and (
                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
                or get_cudnn_version() < 90700
            ):
                pytest.skip(
                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
                    " dropout"
                )
            if not FusedAttnHelper.is_non_deterministic_allowed() and (
                self.dropout_prob != 0.0
                or self.attn_bias_type != AttnBiasType.NO_BIAS
                or get_cudnn_version() < 91801
            ):
                pytest.skip(
                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
                    " dropout"
                )
        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.softmax_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported input combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            self.softmax_offset = jax.random.uniform(
                softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
            )
        else:
            self.softmax_offset = None

        def gen_valid(bs, max_seqlen, pad_ratio):
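            # Dense (non-THD) case: the first valid_len tokens of each row are valid (1)
            # and the rest are padding (0); also return the boolean padding mask.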
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Does not include padding
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len is used to force kv_len >= q_len because cuDNN kernels fail otherwise
                    # TODO(rewang): Remove this constraint once cuDNN supports it
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            # TODO(rewang): record only the self-attention case and find the cause of the cross-attention issue
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for SWA; otherwise cuDNN kernels do not support it
                min_segment_len = None
                if (
                    self.window_size is not None or self.attn_mask_type.is_bottom_right()
                ):  # SWA or BRCM requires kv_len >= q_len
                    min_segment_len = self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
                stripe_size=self.stripe_size,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
                stripe_size=self.stripe_size,
            )
        else:
            # no-ops when CP or load balancing is not used
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    # Exercise the path that generates segment_pos inside from_segment_ids_and_pos()
                    # when CP load balancing is not used; otherwise pass segment_pos explicitly
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            (
                                self.cp_reorder_fn(self.segment_pos_q),
                                self.cp_reorder_fn(self.segment_pos_kv),
                            )
                            if self.cp_size > 1 and self.cp_load_balanced
                            else None
                        ),
                        is_thd=self.qkv_layout.is_thd(),
                        is_segment_ids_reordered=(
                            True if self.cp_size > 1 and self.cp_load_balanced else False
                        ),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                        is_thd=self.qkv_layout.is_thd(),
                        is_segment_ids_reordered=False,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

        # Set up sharding specs for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tpsp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        # Softmax offset sharding (1, num_heads, 1, 1)
        # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
        head_resource = (
            self.mesh_resource.tpsp_resource
            if self.mesh_resource.tpsp_resource is not None
            else self.mesh_resource.tp_resource
        )
        self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

    def test_forward(self):
        """
        Test forward with JITted primitive and unJITted reference
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        if self.is_training and self.dropout_prob > 0.0:
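            # Skip the elementwise comparison when dropout is enabled; the fused kernel
            # and the reference implementation draw different dropout patterns.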
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None, the HLO of the backward function will be
        examined for the expected communication collectives.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)
951

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    customcall_fused_dpa,
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    *args,
                    cp_reverse_out=True,
                    **kwargs,
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
                ),
                arg_nums,
            )
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats(f"primitive_out", primitive_out)
        print_debug_tensor_stats(f"reference_grad_valid", reference_out)
        print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches have the same actual_seqlen; the tests probably need to be extended
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
        pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
        pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
    [
        # large data size + bf16 + qkv packed
        pytest.param(
            2,
            2048,
            2048,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BS3HD,
            id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED",
        ),
        pytest.param(
            2,
            2048,
            2048,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.T3HD,
            id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED",
        ),
        # mid data size + bf16 + cross attn + kv packed
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BSHD_BS2HD,
            id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED",
        ),
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.THD_T2HD,
            id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
        ),
        # large data size + bf16 + cross attn + diff hidden v dim + qkv separate
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            QKVLayout.BSHD_BSHD_BSHD,
            id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE",
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            QKVLayout.THD_THD_THD,
            id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE",
        ),
        # large data size + bf16 + gqa + kv packed
        pytest.param(
            2,
            2048,
            2048,
            12,
            6,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BSHD_BS2HD,
            id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED",
        ),
        pytest.param(
            2,
            2048,
            2048,
            12,
            6,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.THD_T2HD,
            id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED",
        ),
        # small data size + fp16 + diff hidden v dim + qkv packed
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            32,
            jnp.float16,
            QKVLayout.BS3HD,
            id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED",
        ),
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            32,
            jnp.float16,
            QKVLayout.T3HD,
            id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED",
        ),
        # small data size + fp16 + kv packed
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            64,
            jnp.float16,
            QKVLayout.BSHD_BS2HD,
            id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED",
        ),
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            64,
            jnp.float16,
            QKVLayout.THD_T2HD,
            id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED",
        ),
        # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.float16,
            QKVLayout.BSHD_BSHD_BSHD,
            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE",
        ),
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.float16,
            QKVLayout.THD_THD_THD,
            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
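# swa toggles sliding-window attention; the window extent itself is derived from s_kv
# inside each test (see the window_size computation in the test bodies below).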
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
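# Rough summary (see the SeqDescFormat definition earlier in this file): seq_desc_format
# controls how sequence lengths are described to the attention call in these tests —
# a dense attention mask, explicit per-sequence lengths, or segment IDs.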
@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs.
        This test is not intended to run automatically during CI as it is time-consuming;
        it is kept for development and debugging (the leading underscore keeps pytest's
        default collection from picking it up).
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
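        # The (s_kv // 10, 0) window above uses the (left, right) sliding-window convention:
        # each query may attend to roughly the previous 10% of the keys and to nothing beyond
        # its own position.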
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            True,  # is_training: the backward test always runs in training mode
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
    [
        # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.bfloat16,
            QKVLayout.BSHD_BSHD_BSHD,
            id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
        ),
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.bfloat16,
            QKVLayout.THD_THD_THD,
            id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
    ],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
    """
    Fused attention tester with determinism
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs.
        This test is not intended to run automatically during CI as it is time-consuming;
        it is kept for development and debugging.
        """
        TestFusedAttn._test_forward(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            swa,
            seq_desc_format,
        )

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        TestFusedAttn.test_backward(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            qkv_layout,
            bias_shape,
            swa,
            seq_desc_format,
        )
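

# Illustrative invocation (selection string is hypothetical): run only the deterministic
# backward tests, e.g.
#   NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -q test_fused_attn.py -k "Determinism and test_backward"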