test_fused_attn.py 42.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
26
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
27
28
29
30
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
Reese Wang's avatar
Reese Wang committed
31
    QKVFormat,
32
33
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
34
    fused_attn,
35
    make_swa_mask,
36
    SequenceDescriptor,
37
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
38
    ReorderStrategy,
39
)
40
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
41
from transformer_engine_jax import (
42
43
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
44
    get_device_compute_capability,
45
)
46

47
48
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
49

50

51
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error.

    Runs once per module before any test: issuing a trivial JAX op first
    forces JAX to initialize CUDA before any TE custom calls execute.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error
    _ = jnp.zeros(0)
    yield


61
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    Reference implementation used to validate the fused-attention custom calls.

    Args:
        query: (b, s_q, h_q, d) query tensor.
        key: (b, s_kv, h_kv, d) key tensor; requires h_q % h_kv == 0 (GQA).
        value: value tensor; its last dim may differ from d (MLA support).
        bias: optional post-scale bias added to the logits, or None.
        mask: optional boolean mask where `True` means mask out, or None.
        deterministic: when True, dropout is disabled (static arg).
        scale_factor: multiplier applied to the raw logits (static arg).
        dropout_rate: attention-dropout probability (static arg).
        dropout_rng: PRNG key consumed only when dropout is active.
        dtype: dtype the inputs are promoted to before compute (static arg).

    Returns:
        Context tensor of shape query.shape[:-1] + (value.shape[-1],).
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    # GQA: fold query heads into (h_kv, num_groups) so each KV head serves a group
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        # Broadcast a (…, s_q, s_kv) mask over the group axis when needed
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        # `True` in mask -> masked out with the dtype's most negative finite value
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        # Inverted dropout: surviving weights are rescaled by 1/keep_prob
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    # Collapse (h_kv, num_groups) back to h_q; last dim follows value (MLA)
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context
113
114


115
116
117
118
119
120
121
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, arange of the segment_ids will be applied.

    NOTE(review): segment_ids are only used here to derive the default positions
    (shape + arange); the causal comparison itself is purely position based.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # Allow attention only where q position >= kv position (causal)
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
137

138

139
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, using 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]

    Note: attn_mask_type and window_size are static args (they change trace shape).
    """
    # segment masks: attention allowed only within the same non-padding segment
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    # Default positions: arange over the sequence dimension
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    # causal mask
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)

    # Invert: the helpers build "allowed" masks, callers expect "masked-out" masks
    mask = jnp.logical_not(inv_mask)
    return mask
187

188

189
190
@jax.jit
def get_seqlens_and_offsets(segment_ids):
191
192
193
194
195
196
197
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
198
        first_column = x[..., :1] != 0
199
200
201
202
203
204
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
205
206
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
207
208
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
209
210
211
212
213
214
215
216
217
218
219
220
221


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation.

    Thin adapter that maps the test-runner kwargs onto
    general_dot_product_attention and casts the fp32 result back to the
    input dtype.
    """
    run_deterministic = not kwargs["is_training"]
    attn_out = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=run_deterministic,
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return attn_out.astype(query.dtype)
238
239


240
241
242
243
244
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation.

    Packs q/k/v into the tensor arrangement required by kwargs["qkv_layout"]
    (qkv-packed, kv-packed, or fully separate) before invoking fused_attn,
    and casts the result back to the query dtype.
    """
    qkv_layout = kwargs["qkv_layout"]
    add_pack_axis = partial(jnp.expand_dims, axis=-3)
    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # qkv-packed: stack q, k, v along a new axis before H
        packed_qkv = jnp.concatenate(
            [add_pack_axis(t) for t in (query, key, value)], axis=-3
        )
        qkv_args = (packed_qkv,)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # kv-packed: q separate, k and v stacked together
        packed_kv = jnp.concatenate([add_pack_axis(t) for t in (key, value)], axis=-3)
        qkv_args = (query, packed_kv)
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
zlsh80826's avatar
zlsh80826 committed
269
270


271
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.

    Letters denote dimensions (B=batch, H=heads, S=sequence); a `1` means the
    dimension is broadcast.
    """

    _1HSS = "1HSS"  # (1, num_heads, seqlen_q, seqlen_kv)
    _B1SS = "B1SS"  # (batch, 1, seqlen_q, seqlen_kv)
    _BHSS = "BHSS"  # (batch, num_heads, seqlen_q, seqlen_kv)
    _11SS = "11SS"  # (1, 1, seqlen_q, seqlen_kv)
280
281


282
283
284
285
286
287
class SeqDescFormat(Enum):
    """Format used to describe valid sequences when calling fused attention."""

    Mask = auto()  # dense attention mask tensor
    Seqlens = auto()  # per-batch sequence lengths (and offsets for THD)
    SegmentIDs = auto()  # segment ids (and segment positions)


zlsh80826's avatar
zlsh80826 committed
288
289
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner.

    Holds one test configuration and builds the inputs, masks/sequence
    descriptors, and (optionally distributed) sharding specs needed to compare
    the TE fused-attention custom call against the JAX reference.
    """

    # Problem sizes
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int  # may differ from num_heads_q for GQA
    head_dim_qk: int
    head_dim_v: int  # may differ from head_dim_qk for MLA
    # Attention configuration
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]  # None disables sliding-window attention
    seq_desc_format: SeqDescFormat

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

324
325
326
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        """Return the max_segments_per_seq value to pass to the fused-attn call.

        THD layouts on cuDNN [9.4.0, 9.5.0) use the exact segment count to
        avoid the zero-length ragged-tensor issue above; other cuDNN versions
        add one spare slot to also exercise runtime_segments < max_segments.
        Non-THD layouts always use a single segment.
        """
        if self.qkv_layout.is_thd():
            if 90400 <= get_cudnn_version() < 90500:
                return self.num_segments_per_seq
            else:
                # +1 for testing runtime_segments < max_segments
                return self.num_segments_per_seq + 1
        else:
            return 1
335

zlsh80826's avatar
zlsh80826 committed
336
    def _check_configs(self):
        """Skip configurations the fused-attention backend cannot run.

        Filters out invalid layout/mask/bias/head-dim combinations, queries
        FusedAttnHelper for the backend, and skips anything that does not map
        to the F16_arbitrary_seqlen backend. Side effect: sets self.backend.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # qkv-packed layouts share one tensor, so q and kv geometry must match
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        if (
            get_device_compute_capability(0) == 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
            )

        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        # Ask TE which backend would serve this exact configuration
        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS post-scale bias shapes have extra backend restrictions
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
400

zlsh80826's avatar
zlsh80826 committed
401
402
    def _setup_inputs(self):
        """Create test tensors, sequence metadata, and sharding specs.

        Populates q/k/v, optional bias, padding/segment information, the
        reference mask, context-parallel reorder functions, the sequence
        descriptor in the requested format, and NamedShardings for
        distributed runs. Must run after _check_configs (called first here).
        """
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

        # Resolve the bias tensor shape from the configured BiasShape
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Open square windows of 0.0 bias between consecutive random ids
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            # All-valid prefix followed by a padded tail; returns (tokens, pad mask)
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            # Build random THD-style segments; returns (segment_ids, segment_pos, segment_pad)
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        # Randomly mark a suffix of the segment as padding
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            # TODO(rewang): record only self attention and find the reason of cross attention
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
                min_segment_len = None if self.window_size is None else self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tpsp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    # 1-D leaves shard on dp only; 2-D leaves shard on (dp, cp)
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

zlsh80826's avatar
zlsh80826 committed
695
696
697
698
699
    def test_forward(self):
        """
        Run the fused-attention forward pass (without autodiff) and compare the
        custom-call primitive against the pure-JAX reference implementation.

        When ``coll_count_ref`` is set, the compiled HLO is also checked for the
        expected communication collectives.
        """
        self._setup_inputs()

        ref_args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        # Place the (CP-reordered) test data onto each GPU for distributed runs.
        # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
        # THD params once we support those features on CP.
        dist_args = [
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        attn_kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        input_shardings = [
            self.qkvo_sharding,
            self.qkvo_sharding,
            self.qkvo_sharding,
            self.bias_sharding,
            self.seq_desc_sharding,
            self.dropout_rng_sharding,
        ]
        # All attention options are compile-time constants, hence static_argnames.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **attn_kwargs),
            static_argnames=attn_kwargs.keys(),
            in_shardings=input_shardings,
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            fused_out = self.cp_inverse_reorder_fn(customcall_fused_dpa_jit(*dist_args))

        reference_out = jax_dpa(*ref_args, **attn_kwargs)

        # With dropout active in training, the two implementations draw different
        # random masks, so an elementwise comparison is not meaningful.
        if self.is_training and self.dropout_prob > 0.0:
            return

        fused_valid, fused_invalid, reference_valid, _ = _split_valid_and_invalid(
            fused_out, reference_out, self.pad_q
        )

        assert_allclose(fused_invalid, jnp.zeros_like(fused_invalid), dtype=self.dtype)
        assert_allclose(fused_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*dist_args, **attn_kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

zlsh80826's avatar
zlsh80826 committed
763
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected communication collectives.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Reduce the attention output to a scalar loss for value_and_grad.
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            out = func(*args, **kwargs)
            if cp_reverse_out:
                # Undo the context-parallel load-balancing reorder before masking.
                out = self.cp_inverse_reorder_fn(out)
            # Keep only valid (non-padded) result for the gradient
            ret_valid = jnp.where(self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, out)
            # Accumulate in FP32: summing in FP16/BF16 may overflow.
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_grad_valid", reference_out)
        print_debug_tensor_stats("diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            # Split each gradient into valid (in-sequence) and invalid (padded) parts.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # dQ/dK/dV come back in CP-reordered layout; undo before comparing.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947

@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
    [
        pytest.param(
            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-64-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-32-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Parameterized fused-attention test entry points.

    Every class-level parametrize above is combined with the per-method
    parametrizes below; each case delegates to a FusedAttnRunner instance.
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs.

        This test is not intended to run automatically during CI as it is
        time-consuming; it is kept for development and debugging.
        """
        # A sliding window of one tenth of the KV length, looking backwards only.
        window_size = (s_kv // 10, 0) if swa else None
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs.
        """
        # A sliding window of one tenth of the KV length, looking backwards only.
        window_size = (s_kv // 10, 0) if swa else None
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # backward requires is_training
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()