test_fused_attn.py 47.6 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
from transformer_engine.jax import autocast
26
from transformer_engine.jax.sharding import MeshResource
27
28
29
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
30
    AttnSoftmaxType,
31
    QKVLayout,
Reese Wang's avatar
Reese Wang committed
32
    QKVFormat,
33
34
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
35
    fused_attn,
36
    run_length_fill,
37
    make_swa_mask,
38
    SequenceDescriptor,
39
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
40
    ReorderStrategy,
41
)
42
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
43
from transformer_engine_jax import (
44
45
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
46
    get_device_compute_capability,
47
)
48

49
50
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
51

52

53
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error.

    Runs once per test module before any test in it executes. Issuing a
    trivial JAX operation first forces JAX to initialize CUDA; calling TE
    customcalls before JAX has done so may trigger the error.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error
    _ = jnp.zeros(0)
    yield


63
# static_argnums: deterministic(6), softmax_type(7), scale_factor(8),
# dropout_rate(9) and dtype(11) are compile-time constants; dropout_rng(10)
# stays traced.
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    softmax_offset: Optional[ArrayLike],
    bias: Optional[ArrayLike],
    mask: Optional[ArrayLike],
    deterministic: bool,
    softmax_type: AttnSoftmaxType,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    Reference implementation used to validate the TE fused-attention
    customcall. Expected shapes (BSHD layout):
        query:      (b, s_q, h_q, d), h_q must be a multiple of h_kv
        key/value:  (b, s_kv, h_kv, d)
        bias:       broadcastable to (b, h_q, s_q, s_kv); added post-scale
        mask:       `True` marks positions to mask OUT (set to dtype min)
        softmax_offset: learnable logit, reshaped internally to
            (1, h_kv, num_groups, 1, 1); only read for LEARNABLE_SOFTMAX
    Returns the attention context of shape (b, s_q, h_q, d_v).
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    # Split query heads into (h_kv, num_groups) so each KV head serves a group.
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        # Insert a broadcast group axis when the mask lacks one.
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    match softmax_type:
        case AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_out = jax.nn.softmax(logits).astype(dtype)
        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
            # Append a zero logit, apply standard softmax, then remove last column
            zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
            logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # Append learnable offset logit, apply standard softmax, then remove last column
            learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
            learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
            logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case _:
            raise NotImplementedError(f"Unknown {softmax_type=}")

    if not deterministic and dropout_rate > 0.0:
        # Inverted dropout: scale kept probabilities so the expectation is unchanged.
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    # Merge (h_kv, num_groups) back into h_q; last dim follows value's head dim.
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context
135
136


137
138
139
140
141
142
143
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: Optional[ArrayLike] = None,
    segment_pos_kv: Optional[ArrayLike] = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, arange of the segment_ids will be applied.
    """
    # Default positions: 0..seqlen-1 broadcast to the segment_ids shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # Allow attention only where q position >= kv position (lower triangle).
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
159

160

161
# static_argnums: attn_mask_type(4) and window_size(5) are compile-time constants.
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: Optional[ArrayLike],
    segment_pos_kv: Optional[ArrayLike],
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, and use 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks: allow attention only within the same non-padding segment
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    # Default positions: 0..seqlen-1 broadcast to the segment_ids shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    if attn_mask_type.is_bottom_right():
        # Bottom-right causal: align the causal diagonal to the end of each
        # segment by comparing distance-to-segment-end instead of position.
        run_length_out_q = run_length_fill(segment_ids_q)
        run_length_out_kv = run_length_fill(segment_ids_kv)
        bottom_right_causal_mask = make_attention_mask(
            run_length_out_q - segment_pos_q,
            run_length_out_kv - segment_pos_kv,
            jnp.less_equal,
        )
        inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
    elif attn_mask_type.is_causal():
        # Top-left causal: q position must be >= kv position within a segment.
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = (
        make_swa_mask(
            segment_pos_q,
            segment_pos_kv,
            window_size,
            dtype=jnp.bool,
            segment_ids_q=segment_ids_q,
            segment_ids_kv=segment_ids_kv,
        )
        if attn_mask_type.is_bottom_right()
        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
    )
    inv_mask = combine_masks(inv_mask, inv_swa_mask)

    # Flip polarity: inv_mask is "allowed", the returned mask is "masked out".
    mask = jnp.logical_not(inv_mask)
    return mask
228

229

230
231
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """Derive per-segment lengths and start offsets from a segment-id tensor.

    ``segment_ids`` has shape (batch, max_seqlen) with 0 marking padding.
    Returns ``(seqlens, offsets)`` where unused trailing slots are -1 and one
    extra sentinel column (-1) is appended to each output.
    """
    _, seq_len = segment_ids.shape
    # Count tokens per segment id; column 0 counts padding and is dropped.
    counts = jax.vmap(partial(jnp.bincount, length=seq_len))(segment_ids.astype(jnp.int32))
    lengths = counts[..., 1:]

    def _segment_starts(ids):
        # A start is any non-zero token whose left neighbor has a different id.
        boundary = (ids[..., 1:] != ids[..., :-1]) & (ids[..., 1:] != 0)
        leading = ids[..., :1] != 0
        starts_mask = jnp.concatenate((leading, boundary), axis=-1)
        locate = jax.vmap(partial(jnp.argwhere, size=ids.shape[1], fill_value=-1))
        return locate(starts_mask).squeeze(-1)

    starts = _segment_starts(segment_ids)
    # Append sentinel columns and mark empty length slots with -1.
    starts = jnp.insert(starts, starts.shape[-1], values=-1, axis=-1)
    lengths = jnp.insert(lengths, lengths.shape[-1], values=0, axis=-1)
    lengths = jnp.where(lengths, lengths, -1)
    return lengths, starts
250
251
252
253
254
255
256
257
258
259
260
261


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


262
def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation.

    Reference path the fused-attention customcall is compared against.
    Runs the computation in float32 for accuracy, then casts back to the
    input dtype.
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        softmax_offset,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        softmax_type=kwargs["softmax_type"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)
281
282


283
284
285
286
287
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    softmax_offset,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation.

    Packs the separate query/key/value tensors into the tensor arrangement
    required by ``kwargs["qkv_layout"]`` before dispatching to ``fused_attn``,
    and casts the result back to the query dtype.
    """
    qkv_layout = kwargs["qkv_layout"]
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            # Pack Q, K, V into one tensor along a new 3-sized axis.
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            # Q stays separate; K and V are packed along a new 2-sized axis.
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(
        qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
    ).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
313
314


315
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.

    Letters encode the bias tensor layout (B=batch, H=heads, S=sequence,
    1=broadcast dimension of size one).
    """

    _1HSS = "1HSS"  # (1, num_heads, seqlen_q, seqlen_kv)
    _B1SS = "B1SS"  # (batch, 1, seqlen_q, seqlen_kv)
    _BHSS = "BHSS"  # (batch, num_heads, seqlen_q, seqlen_kv)
    _11SS = "11SS"  # (1, 1, seqlen_q, seqlen_kv)
324
325


326
327
328
329
330
331
class SeqDescFormat(Enum):
    """Ways a test can describe valid sequence regions to the fused attention call."""

    Mask = auto()  # dense attention mask (not supported for THD layouts)
    Seqlens = auto()  # per-segment sequence lengths and offsets
    SegmentIDs = auto()  # segment id (and optionally position) tensors


zlsh80826's avatar
zlsh80826 committed
332
333
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    # Problem sizes
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int  # may differ from num_heads_q (GQA)
    head_dim_qk: int
    head_dim_v: int  # may differ from head_dim_qk; requires a separate QKV layout

    # Attention configuration
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Optional[Tuple[int, int]]  # sliding window (left, right); None disables SWA
    seq_desc_format: SeqDescFormat

    # Stripe size for the Striped CP reorder strategy (None -> library default)
    stripe_size: int | None = None
    # Segments per sequence for THD tests; default resolved in __post_init__
    num_segments_per_seq: int | None = None

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

371
372
373
374
375
376
377
378
    def __post_init__(self):
        """Resolve the default segment count when the caller did not set one."""
        # THD layouts pack multiple segments per sequence; other layouts use one.
        if self.num_segments_per_seq is None:
            self.num_segments_per_seq = 2 if self.qkv_layout.is_thd() else 1

379
380
381
    def _get_max_segments_per_sequence(self):
        """Return the max segment slots to allocate per sequence.

        See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0
        for a known issue generating zero-length ragged tensors; on cuDNN 9.4.x
        we therefore do not reserve an extra (possibly empty) segment slot.
        """
        # Non-THD layouts always hold exactly one segment per sequence.
        if not self.qkv_layout.is_thd():
            return 1
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1
390

zlsh80826's avatar
zlsh80826 committed
391
    def _check_configs(self):
        """Skip parameter combinations unsupported by cuDNN/TE and resolve the backend.

        Side effect: sets ``self.backend`` from ``FusedAttnHelper``.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # Bottom-right causal mask (BRCM) restrictions
        if self.attn_mask_type.is_bottom_right():
            if self.max_seqlen_q > self.max_seqlen_kv:
                pytest.skip(
                    f"BRCM requires cross attn type pattern, i.e.max_seqlen_kv >= max_seqlen_q"
                )
            if self.attn_bias_type is not AttnBiasType.NO_BIAS:
                pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM")
            if self.dropout_prob != 0.0:
                pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM")

        # Packed layouts require Q and KV to share shapes
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )
        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
        if (
            get_device_compute_capability(0) >= 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
            )

        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        # Ask TE which fused-attention backend this configuration maps to.
        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.softmax_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS bias shapes are only supported in narrow conditions.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
466

zlsh80826's avatar
zlsh80826 committed
467
468
    def _setup_inputs(self):
        self._check_configs()
469
470
471
472
473
474

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
475
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
476

zlsh80826's avatar
zlsh80826 committed
477
        key = jax.random.PRNGKey(0)
478
        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
479

480
481
482
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)
483

484
485
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
486
        elif self.bias_shape == BiasShape._1HSS:
487
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
488
        elif self.bias_shape == BiasShape._B1SS:
489
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
490
        elif self.bias_shape == BiasShape._BHSS:
491
492
493
494
495
496
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
497
        elif self.bias_shape == BiasShape._11SS:
498
499
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
500
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
501

502
503
504
        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
505
506

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
507
            if self.bias_shape == BiasShape._1HSS:
508
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
509
510
511
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
512
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
513
514
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
515
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
516
517
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
518
519
520
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
521
522
        else:
            self.bias = None
523

524
        if self.attn_mask_type.is_padding():
525
            pad_ratio = 0.3
526
527
        else:
            pad_ratio = 0.0
528

529
530
531
532
533
534
535
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            self.softmax_offset = jax.random.uniform(
                softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
            )
        else:
            self.softmax_offset = None

zlsh80826's avatar
zlsh80826 committed
536
537
538
539
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
540
541
542
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
543
544
545
546
547
548
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
549
550
551
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
552
553
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
554
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
555
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
556
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
557
558
559
560
561
562
563

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

564
565
566
567
568
569
570
                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
571
572
573
574
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
575
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
576
                    if with_segment_pad:
577
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
578
579
580
581
582
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

583
584
585
586
587
588
589
590
            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
591
592
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
593
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
Reese Wang's avatar
Reese Wang committed
594
595
            # TODO(rewang): record only self attention and find the reason of cross attention
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
596
597
598
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
599
            else:
600
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
601
602
603
604
605
                min_segment_len = None
                if (
                    self.window_size is not None or self.attn_mask_type.is_bottom_right()
                ):  # SWA or BRCM requires kv_len >= q_len
                    min_segment_len = self.seqlens_q
606
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
607
608
609
610
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
611
                    min_segment_len=min_segment_len,
612
                )
613
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
614
        else:
615
616
617
618
619
620
621
622
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
623

624
        # For reference code
625
        self.mask = make_mask(
626
627
628
629
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
630
            self.attn_mask_type,
631
            self.window_size,
632
        )
633

Reese Wang's avatar
Reese Wang committed
634
635
636
637
638
639
640
641
642
643
644
645
        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
646
                stripe_size=self.stripe_size,
Reese Wang's avatar
Reese Wang committed
647
648
649
650
651
652
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
653
                stripe_size=self.stripe_size,
Reese Wang's avatar
Reese Wang committed
654
655
656
657
658
659
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

660
        # Test different input formats
661
        if self.qkv_layout.is_thd():
662
663
664
665
666
667
668
669
670
671
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
Reese Wang's avatar
Reese Wang committed
672
673
674
675
676
677
678
679
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
680
681
682
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
683
        else:
684
685
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
686
687
688
689
690
691
692
693
694
695
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
696
697
698
699
700
701
702
703
704
705
706
707
708
709
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
710

zlsh80826's avatar
zlsh80826 committed
711
        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
712
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
713

714
715
716
717
718
        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
719
            self.mesh_resource.tpsp_resource,
720
721
722
723
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

724
        mask_pspec = PartitionSpec(
725
726
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
727
728
729
730
731
732
733
734
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
Reese Wang's avatar
Reese Wang committed
735
736
737
738
739
740
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
741
742
743
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
744
745
746

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
747
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
748
749
750
751
752
753
754
755
756
757
758
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

759
760
761
762
763
764
765
766
767
768
        # Softmax offset sharding (1, num_heads, 1, 1)
        # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
        head_resource = (
            self.mesh_resource.tpsp_resource
            if self.mesh_resource.tpsp_resource is not None
            else self.mesh_resource.tp_resource
        )
        self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)

769
770
771
772
773
774
775
776
777
778
779
780
781
        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

    def test_forward(self):
        """
        Test the forward pass of fused attention.

        Runs the JITted custom-call primitive and the unJITted pure-JAX
        reference on the same inputs, then compares the valid (non-padded)
        region of the outputs. When training with dropout enabled the
        elementwise comparison is skipped. If ``self.coll_count_ref`` is set,
        the compiled HLO is additionally checked for the expected
        communication collectives.
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        # All config kwargs are compile-time constants, hence static_argnames.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the context-parallel load-balancing reorder before comparing.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded positions must be exactly zero; valid positions must match the reference.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected communication collectives.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient.
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid (non-padded) results for the gradient; optionally
            # undo the context-parallel reorder of the output first.
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout.
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation.
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    customcall_fused_dpa,
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    *args,
                    cp_reverse_out=True,
                    **kwargs,
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
                ),
                arg_nums,
            )
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled.
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_grad_valid", reference_out)
        print_debug_tensor_stats("diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            # Compare one of dQ/dK/dV: padded region must be zero and match the
            # reference; valid region must match the reference elementwise.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Undo the context-parallel load-balancing reorder before comparing.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
        pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
        pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
    [
        pytest.param(
            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
        ),
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            id="2-512-1024-12-12-64-64-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-32-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        # A sliding window of roughly one tenth of the KV length when SWA is on.
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        # A sliding window of roughly one tenth of the KV length when SWA is on.
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            True,  # backward requires is_training=True
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()