test_fused_attn.py 40.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
26
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
27
28
29
30
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
Reese Wang's avatar
Reese Wang committed
31
    QKVFormat,
32
33
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
34
    fused_attn,
35
    make_swa_mask,
36
    SequenceDescriptor,
37
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
38
    ReorderStrategy,
39
)
40
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
41
from transformer_engine_jax import (
42
43
44
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)
45

46
47
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
48

49

50
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error.

    Touches JAX once by allocating an empty array before any test in the module
    runs a TE custom call. The array stays referenced for the fixture's lifetime.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error
    _warmup = jnp.zeros(0)
    yield


60
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Reference dot-product attention with grouped-query attention (GQA) support.

    This follows flax.linen.dot_product_attention. The h_q query heads are viewed as
    h_kv groups of h_q // h_kv heads, and each group attends against one shared
    key/value head. `deterministic`, `scale_factor`, `dropout_rate` and `dtype` are
    static arguments under jit.

    Args:
        query: [b, s_q, h_q, d] tensor.
        key: [b, s_kv, h_kv, d] tensor. h_q must be a multiple of h_kv.
        value: [b, s_kv, h_kv, d] tensor, contracted with the attention weights.
        bias: post-scale bias added to the logits viewed as [b, h_q, s_q, s_kv],
            or None.
        mask: where the mask is True, logits are replaced by the dtype's minimum
            before softmax. A mask whose rank differs from the 5-D grouped logits
            gets a group axis inserted at -3. None disables masking.
        deterministic: when True, dropout is skipped.
        scale_factor: multiplier applied to Q @ K^T.
        dropout_rate: probability of zeroing an attention weight.
        dropout_rng: PRNG key for the dropout draw.
        dtype: dtype the inputs are promoted to before computing.

    Returns:
        The attention context, reshaped to the shape of `query`.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    compute_dtype = query.dtype

    batch, seqlen_q, heads_q, head_dim = query.shape
    _, seqlen_kv, heads_kv, _ = key.shape
    assert (heads_q % heads_kv == 0) and (heads_q >= heads_kv)
    group_size = heads_q // heads_kv

    # View queries as [b, s_q, h_kv, g, d]; logits come out as [b, h_kv, g, s_q, s_kv].
    grouped_query = jnp.reshape(query, (batch, seqlen_q, heads_kv, group_size, head_dim))
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # The bias is laid out per query head, so fold the groups back into heads
        # while adding it (post-scale bias), then restore the grouped view.
        per_head = logits.reshape((batch, heads_kv * group_size, seqlen_q, seqlen_kv)) + bias
        logits = per_head.reshape((batch, heads_kv, group_size, seqlen_q, seqlen_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            # Insert a group axis so the mask broadcasts over the grouped heads.
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(compute_dtype).min, logits)

    weights = jax.nn.softmax(logits).astype(compute_dtype)

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, weights.shape)
        # Inverted dropout: rescale the surviving weights by 1 / keep_prob.
        scale = keep.astype(compute_dtype) / jnp.asarray(keep_prob, dtype=compute_dtype)
        weights = weights * scale

    context = jnp.einsum("...hgqk,...khd->...qhgd", weights, value)
    return jnp.reshape(context, query.shape)
111
112


113
114
115
116
117
118
119
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create an inverse padded causal mask.

    A set entry means the (query, key) pair may participate in attention. A cleared
    entry means the pair is masked out. A pair is kept when the query position is
    greater than or equal to the key position.

    If a segment_pos argument is not provided, an arange over the last axis of the
    matching segment_ids, broadcast to its shape, is used instead.
    """

    def _arange_positions(segment_ids):
        return jnp.broadcast_to(
            jnp.arange(segment_ids.shape[-1], dtype=jnp.int32), segment_ids.shape
        )

    pos_q = segment_pos_q if segment_pos_q is not None else _arange_positions(segment_ids_q)
    pos_kv = segment_pos_kv if segment_pos_kv is not None else _arange_positions(segment_ids_kv)
    return make_attention_mask(pos_q, pos_kv, jnp.greater_equal)
135

136

137
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create an attention mask based on the mask type.

    A `True` value in the returned mask means the corresponding position is masked
    out. A `False` value means the position participates in attention. The mask
    combines three parts:
      - a segment mask: same, non-zero segment id,
      - a causal mask when `attn_mask_type.is_causal()`,
      - a sliding-window mask built by `make_swa_mask` from `window_size`.

    - segment_ids should start with 1 and use 0s for the paddings.
      Each segment is expected to start without paddings.
    - segment_pos marks the token position within its segment. When None, an
      arange over the last axis of the matching segment_ids is used.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """

    def _arange_like(segment_ids):
        return jnp.broadcast_to(
            jnp.arange(segment_ids.shape[-1], dtype=jnp.int32), segment_ids.shape
        )

    def _same_real_segment(x, y):
        # Attend only within one segment; id 0 is padding and never attends.
        return jnp.logical_and(jnp.equal(x, y), x != 0)

    # segment masks
    keep = make_attention_mask(segment_ids_q, segment_ids_kv, _same_real_segment)

    if segment_pos_q is None:
        segment_pos_q = _arange_like(segment_ids_q)
    if segment_pos_kv is None:
        segment_pos_kv = _arange_like(segment_ids_kv)

    # causal mask
    if attn_mask_type.is_causal():
        causal_keep = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
        keep = combine_masks(causal_keep, keep)

    # sliding window mask
    swa_keep = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    keep = combine_masks(keep, swa_keep)

    # Everything built so far marks positions to keep; invert to "True = masked out".
    return jnp.logical_not(keep)
185

186

187
188
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """
    Derive per-segment lengths and start offsets from packed segment ids.

    Args:
        segment_ids: [batch, max_seqlen] array where 0 marks padding and every
            segment has a positive id. Ids are counted with `jnp.bincount(...,
            length=max_seqlen)`, so presumably they must stay below max_seqlen.

    Returns:
        seqlens: [batch, max_seqlen] integer array. Entry j is the number of
            non-padding tokens carrying segment id j + 1. Ids that never occur, and
            the trailing sentinel column, are reported as -1.
        offsets: [batch, max_seqlen + 1] array. It lists, in ascending order, the
            positions where a segment starts (the first token when it is non-padding,
            or any non-padding token whose id differs from its predecessor), padded
            with -1.
    """
    _, max_seqlen = segment_ids.shape
    count_ids = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    # Column 0 of the histogram counts padding tokens (id 0); drop it.
    seqlens = count_ids(segment_ids.astype(jnp.int32))[..., 1:]

    def _find_offsets(x):
        # A token opens a new segment when it is non-padding and differs from the
        # token before it. The very first token only needs to be non-padding.
        starts_after_first = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        starts_at_first = x[..., :1] != 0
        is_segment_start = jnp.hstack((starts_at_first, starts_after_first))
        # Fixed-size argwhere keeps shapes static under jit; unused slots become -1.
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            is_segment_start
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    # Append one sentinel column to each result.
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    # Report absent segments with -1 instead of 0.
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
207
208
209
210
211
212
213
214
215
216
217
218
219


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """
    Split both outputs into valid and padded parts (JIT-compiled to speed up checks).

    `pad` gets two trailing singleton axes so it broadcasts over the last two axes
    of `primitive` and `reference`. The "valid" parts zero out padded positions; the
    "invalid" parts zero out everything else.

    Returns:
        (primitive_valid, primitive_invalid, reference_valid, reference_invalid)
    """
    is_pad = pad[..., jnp.newaxis, jnp.newaxis]

    def _split(x):
        return jnp.where(is_pad, 0, x), jnp.where(is_pad, x, 0)

    prim_valid, prim_invalid = _split(primitive)
    ref_valid, ref_invalid = _split(reference)
    return prim_valid, prim_invalid, ref_valid, ref_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation (the test reference).

    Only three `kwargs` entries are read:
      - `is_training`: dropout is applied only while training,
      - `scaling_factor`,
      - `dropout_probability`.
    Other entries are accepted and ignored, so the same kwargs dict can be shared
    with the fused path. The computation runs in float32 and the result is cast back
    to `query.dtype`.
    """
    training = kwargs["is_training"]
    scale = kwargs["scaling_factor"]
    dropout_prob = kwargs["dropout_probability"]

    context = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=not training,
        scale_factor=scale,
        dropout_rate=dropout_prob,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return context.astype(query.dtype)
236
237


238
239
240
241
242
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
243
    sequence_descriptor,
244
245
246
    dropout_rng,
    **kwargs,
):
247
    """
zlsh80826's avatar
zlsh80826 committed
248
    TE customcall dot product attention implementation
249
    """
250
    qkv_layout = kwargs["qkv_layout"]
zlsh80826's avatar
zlsh80826 committed
251
    match qkv_layout:
252
        case QKVLayout.BS3HD | QKVLayout.T3HD:
zlsh80826's avatar
zlsh80826 committed
253
254
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
255
256
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
zlsh80826's avatar
zlsh80826 committed
257
258
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
259
260
261
262
263
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
264
265
266
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
zlsh80826's avatar
zlsh80826 committed
267
268


269
class BiasShape(Enum):
    """
    Layouts of the attention bias tensor exercised by the fused attention tests.

    Each value spells out the four bias dimensions in order: `1` is a size-1
    (broadcast) axis, `B` is batch, `H` is query heads, and `SS` is the
    (max_seqlen_q, max_seqlen_kv) score plane.
    """

    _1HSS = "1HSS"  # (1, num_heads_q, max_seqlen_q, max_seqlen_kv)
    _B1SS = "B1SS"  # (batch_size, 1, max_seqlen_q, max_seqlen_kv)
    _BHSS = "BHSS"  # (batch_size, num_heads_q, max_seqlen_q, max_seqlen_kv)
    _11SS = "11SS"  # (1, 1, max_seqlen_q, max_seqlen_kv)
278
279


280
281
282
283
284
285
class SeqDescFormat(Enum):
    """
    How the tests describe sequence structure to `fused_attn`.

    Mask:       an explicit attention mask (THD layouts skip this format).
    Seqlens:    per-sequence lengths (plus offsets for THD layouts).
    SegmentIDs: segment ids (plus segment positions for THD layouts).
    """

    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


zlsh80826's avatar
zlsh80826 committed
286
287
@dataclass
class FusedAttnRunner:
288
    """
zlsh80826's avatar
zlsh80826 committed
289
    Fused attention runner
290
    """
291

zlsh80826's avatar
zlsh80826 committed
292
293
294
295
296
297
298
299
300
301
302
303
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
304
    bias_shape: BiasShape
305
306
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat
zlsh80826's avatar
zlsh80826 committed
307

308
309
310
311
312
313
314
315
316
317
318
319
320
    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

321
322
323
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        """
        Return the `max_segments_per_seq` value passed to `fused_attn`.

        Non-THD layouts always use 1. THD layouts use `self.num_segments_per_seq`
        plus one spare slot, which tests runtime_segments < max_segments. The spare
        slot is omitted for cuDNN versions in [90400, 90500) because of the
        zero-length ragged tensor issue in the cuDNN 9.4.0 release notes.
        """
        if not self.qkv_layout.is_thd():
            return 1
        has_zero_length_issue = 90400 <= get_cudnn_version() < 90500
        spare_slots = 0 if has_zero_length_issue else 1
        return self.num_segments_per_seq + spare_slots
332

zlsh80826's avatar
zlsh80826 committed
333
    def _check_configs(self):
        """
        Skip the current test when its configuration cannot run.

        Checks run in this order, and each failing check calls `pytest.skip`:
          1. THD layouts need a padding mask.
          2. qkv-packed layouts need equal sequence lengths and head counts.
          3. Sliding window needs max_seqlen_q <= max_seqlen_kv.
          4. A fused-attention backend must be available. The chosen backend is
             stored in `self.backend`.
          5. Non-1HSS post-scale bias shapes need a non-padding mask and the
             F16_arbitrary_seqlen backend.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        uses_swa = self.window_size is not None
        if uses_swa and self.max_seqlen_q > self.max_seqlen_kv:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        # (-1, -1) is what this test passes to the helper when sliding window is off.
        helper_window = (-1, -1) if self.window_size is None else self.window_size
        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            helper_window,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        has_non_1hss_bias = (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        )
        if not has_non_1hss_bias:
            return
        if self.attn_mask_type.is_padding():
            pytest.skip(
                "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
            )
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip(
                "B1SS, BHSS and 11SS bias shapes are only supported for "
                "the F16_arbitrary_seqlen backend."
            )
380

zlsh80826's avatar
zlsh80826 committed
381
382
    def _setup_inputs(self):
        self._check_configs()
383
384
385
386
387
388
389
390

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)

zlsh80826's avatar
zlsh80826 committed
391
392
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
393

zlsh80826's avatar
zlsh80826 committed
394
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
395
396
397
398
399
400
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )
401

402
403
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
404
        elif self.bias_shape == BiasShape._1HSS:
405
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
406
        elif self.bias_shape == BiasShape._B1SS:
407
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
408
        elif self.bias_shape == BiasShape._BHSS:
409
410
411
412
413
414
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
415
        elif self.bias_shape == BiasShape._11SS:
416
417
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
418
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
419

420
421
422
        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
423
424

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
425
            if self.bias_shape == BiasShape._1HSS:
426
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
427
428
429
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
430
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
431
432
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
433
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
434
435
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
436
437
438
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
439
440
        else:
            self.bias = None
441

442
        if self.attn_mask_type.is_padding():
443
            pad_ratio = 0.3
444
445
        else:
            pad_ratio = 0.0
446

zlsh80826's avatar
zlsh80826 committed
447
448
449
450
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
451
452
453
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
454
455
456
457
458
459
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
460
461
462
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
463
464
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
465
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
466
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
467
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
468
469
470
471
472
473
474

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

475
476
477
478
479
480
481
                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
482
483
484
485
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
486
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
487
                    if with_segment_pad:
488
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
489
490
491
492
493
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

494
495
496
497
498
499
500
            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
501
            self.num_segments_per_seq = 2
502
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
503
504
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
505
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
Reese Wang's avatar
Reese Wang committed
506
507
            # TODO(rewang): record only self attention and find the reason of cross attention
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
508
509
510
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
511
            else:
512
513
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
                min_segment_len = None if self.window_size is None else self.seqlens_q
514
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
515
516
517
518
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
519
                    min_segment_len=min_segment_len,
520
                )
521
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
522
523
        else:
            self.num_segments_per_seq = 1
524
525
526
527
528
529
530
531
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
532

533
        # For reference code
534
        self.mask = make_mask(
535
536
537
538
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
539
            self.attn_mask_type,
540
            self.window_size,
541
        )
542

Reese Wang's avatar
Reese Wang committed
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

567
        # Test different input formats
568
        if self.qkv_layout.is_thd():
569
570
571
572
573
574
575
576
577
578
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
Reese Wang's avatar
Reese Wang committed
579
580
581
582
583
584
585
586
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
587
588
589
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
590
        else:
591
592
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
593
594
595
596
597
598
599
600
601
602
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
603
604
605
606
607
608
609
610
611
612
613
614
615
616
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
617

zlsh80826's avatar
zlsh80826 committed
618
        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
619
        self.scaling_factor = 1.0 / sqrt(self.head_dim)
620

621
622
623
624
625
626
627
628
629
630
        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

631
        mask_pspec = PartitionSpec(
632
633
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
634
635
636
637
638
639
640
641
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
Reese Wang's avatar
Reese Wang committed
642
643
644
645
646
647
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
648
649
650
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

zlsh80826's avatar
zlsh80826 committed
679
680
681
682
683
    def test_forward(self):
        """
        Test the forward pass of the fused-attention custom call against the JAX reference.

        The custom call is run under ``jit`` with explicit input shardings inside
        ``self.mesh``; the reference ``jax_dpa`` is called directly on the unsharded inputs.
        Outputs are compared elementwise unless the run is in training mode with dropout
        enabled. If ``coll_count_ref`` is not None, the compiled HLO of the custom call is
        additionally checked for the expected collectives.
        """
        self._setup_inputs()

        # Reference inputs: plain (unsharded, non-reordered) tensors plus the dense mask.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # The attention config is bound via ``partial`` and declared static, so the jitted
        # callable takes only the six array arguments described by ``in_shardings``.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the context-parallel reordering so the output lines up with the reference.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        # Skip elementwise comparison when dropout is active in training mode.
        # NOTE(review): this early return also skips the collective-count check below;
        # confirm that is intended.
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, _reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded query positions of the custom-call output must be zero.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                # Lower with the same call signature that was executed above; the static
                # kwargs are already bound through ``partial``.
                target_hlo = customcall_fused_dpa_jit.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        Both the custom call and the JAX reference are wrapped in ``value_and_grad`` of a
        scaled mean over the valid (non-padded) query positions. The loss values and the
        dQ/dK/dV gradients are compared; dBias is compared too for the [1, h, s, s] bias
        layout. Comparisons are skipped when dropout is enabled.

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected comms.
        """
        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            """Reduce ``func(*args, **kwargs)`` to a scaled scalar loss over valid query rows."""
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            out = func(*args, **kwargs)
            if cp_reverse_out:
                # Undo the context-parallel reordering so pad_q lines up with the output.
                out = self.cp_inverse_reorder_fn(out)
            # Keep only valid result for the gradient
            ret_valid = jnp.where(self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, out)
            # Summing in FP16/BF16 may overflow, so the reduction is done in FP32
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )

        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        # NOTE(review): this early return also skips the collective-count check below;
        # confirm that is intended.
        if self.dropout_prob > 0.0:
            return

        # These are the scalar loss values returned by value_and_grad, not gradients.
        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_out", reference_out)
        print_debug_tensor_stats("diff_out", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            """
            Compare one gradient tensor against the reference.

            ``idx`` (0=dQ, 1=dK, 2=dV) only labels the debug output; the whole tensors are
            reported and compared.
            """
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid)
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid)
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid - reference_valid)
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Gradients of the custom call come back in context-parallel order; restore it.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Parameterized entry points that build a FusedAttnRunner for every combination of the
    class-level parametrize decorators and run its forward or backward check.
    """

    @staticmethod
    def _window_size(swa, s_kv):
        """Return ``(s_kv // 10, 0)`` as the sliding-window size when ``swa`` is set, else None."""
        return (s_kv // 10, 0) if swa else None

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Run the forward-only comparison for one parameterized configuration.

        The leading underscore keeps pytest from collecting it: the sweep is slow, so it is
        meant for manual development and debugging rather than CI.
        """
        FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            TestFusedAttn._window_size(swa, s_kv),
            seq_desc_format,
        ).test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Run the forward+backward comparison for one parameterized configuration
        (always in training mode).
        """
        training = True
        FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            training,
            qkv_layout,
            bias_shape,
            TestFusedAttn._window_size(swa, s_kv),
            seq_desc_format,
        ).test_backward()