test_fused_attn.py 41.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
26
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
27
28
29
30
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
Reese Wang's avatar
Reese Wang committed
31
    QKVFormat,
32
33
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
34
    fused_attn,
35
    make_swa_mask,
36
    SequenceDescriptor,
37
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
38
    ReorderStrategy,
39
)
40
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
41
from transformer_engine_jax import (
42
43
44
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)
45

46
47
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
48

49

50
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error.

    Runs automatically once per test module, before any test in this file.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error;
    # touching jnp first forces JAX to initialize its CUDA context.
    _ = jnp.zeros(0)
    yield


60
# static_argnums: 5=deterministic, 6=scale_factor, 7=dropout_rate, 9=dtype
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    query is (b, s_q, h_q, d) and key/value are (b, s_kv, h_kv, d) with
    h_q an integer multiple of h_kv; query heads are grouped so each KV head
    serves h_q // h_kv query heads. `mask` uses True to mask OUT a position.
    Returns the attention context with the same leading shape as `query` and
    the value head dim in the last axis.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    # Split query heads into (kv_head, group) so the einsum broadcasts each
    # KV head across its group of query heads (GQA).
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups so the (b/1, h, s, s) bias broadcasts
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            # Insert the group axis so a (b, h, s_q, s_kv) mask broadcasts.
            mask = jnp.expand_dims(mask, axis=-3)
        # True in `mask` means masked out: push the logit to the dtype minimum.
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        # Inverted dropout on attention weights: scale kept entries by 1/keep_prob.
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    # Collapse (kv_head, group) back into h_q; last dim comes from value.
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context
112
113


114
115
116
117
118
119
120
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the segment_ids' last axis is used
    as the position.
    """
    # Default positions: 0..seqlen-1, broadcast to the segment_ids shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # Causal: a query may attend only to KV positions at or before its own position.
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
136

137

138
# static_argnums: 4=attn_mask_type, 5=window_size
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, and using 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.

    A example pair of segments_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks: attend only within the same (non-padding) segment id
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    # Default positions: 0..seqlen-1, broadcast to the segment_ids shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    # causal mask (positions are per-segment, so causality is within segments)
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)
    # combine_masks produces an "allow" mask; invert so True means masked out.
    mask = jnp.logical_not(inv_mask)
    return mask
186

187

188
189
@jax.jit
def get_seqlens_and_offsets(segment_ids):
190
191
192
193
194
195
196
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
197
        first_column = x[..., :1] != 0
198
199
200
201
202
203
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
204
205
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
206
207
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
208
209
210
211
212
213
214
215
216
217
218
219
220


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation (reference path).

    Runs the fp32 reference attention and casts the result back to the
    input dtype. `kwargs` must carry is_training, scaling_factor and
    dropout_probability.
    """
    # Pull the reference-relevant knobs out of the shared kwargs dict.
    deterministic = not kwargs["is_training"]
    scale = kwargs["scaling_factor"]
    drop_rate = kwargs["dropout_probability"]

    output = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=deterministic,
        scale_factor=scale,
        dropout_rate=drop_rate,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)
237
238


239
240
241
242
243
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation.

    Packs q/k/v according to kwargs["qkv_layout"] and dispatches to
    fused_attn, casting the output back to the query dtype.
    """
    qkv_layout = kwargs["qkv_layout"]
    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # Fully packed: stack q, k, v along a new interleave axis.
        query, key, value = (jnp.expand_dims(t, axis=-3) for t in (query, key, value))
        qkv = jnp.concatenate((query, key, value), axis=-3)
        qkv_args = (qkv,)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # KV packed: q separate, k and v stacked together.
        key, value = (jnp.expand_dims(t, axis=-3) for t in (key, value))
        kv = jnp.concatenate((key, value), axis=-3)
        qkv_args = (query, kv)
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        # Fully separate tensors.
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
zlsh80826's avatar
zlsh80826 committed
268
269


270
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.

    Letters name the four bias axes (B=batch, H=heads, S=seqlen); a `1`
    marks an axis of size one that broadcasts.
    """

    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"
279
280


281
282
283
284
285
286
class SeqDescFormat(Enum):
    """Ways the tests describe sequence/padding information to fused_attn."""

    Mask = auto()  # dense boolean attention mask
    Seqlens = auto()  # per-batch sequence lengths (and offsets for THD)
    SegmentIDs = auto()  # per-token segment ids (and positions)


zlsh80826's avatar
zlsh80826 committed
287
288
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner: holds one test configuration and the derived
    inputs/shardings used to compare the TE fused attention customcall
    against the JAX reference implementation.
    """

    # Problem sizes
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int  # may differ from num_heads_q for GQA
    head_dim_qk: int
    head_dim_v: int  # may differ from head_dim_qk (MLA), separate layouts only
    # Attention configuration
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]  # sliding-window extents passed to make_swa_mask
    seq_desc_format: SeqDescFormat

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

323
324
325
    def _get_max_segments_per_sequence(self):
        """Return the max_segments_per_seq value to pass to fused_attn."""
        # Non-THD layouts always carry exactly one segment per sequence.
        if not self.qkv_layout.is_thd():
            return 1
        # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0
        # for a known issue generating zero-length ragged tensors; on those cuDNN
        # versions keep max_segments equal to the actual segment count.
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1
334

zlsh80826's avatar
zlsh80826 committed
335
    def _check_configs(self):
        """
        Skip unsupported configurations and resolve the fused-attention backend.

        Side effect: sets `self.backend` from FusedAttnHelper. The order of
        checks matters — each pytest.skip short-circuits the rest.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # Packed layouts share one tensor for q/k/v, so their shapes must agree.
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        # Ask TE which fused-attention backend (if any) supports this config.
        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS bias shapes are only exercised on a subset of configs.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
391

zlsh80826's avatar
zlsh80826 committed
392
393
    def _setup_inputs(self):
        """
        Validate the config, then build every input the test needs: q/k/v
        tensors, optional bias, padding/segment metadata, the reference mask,
        context-parallel reorder functions, the SequenceDescriptor in the
        requested format, and the shardings for distributed runs.
        """
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)

        # Deterministic RNG: one subkey per generated tensor.
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

        # Resolve the bias tensor shape from the BiasShape enum.
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Open random square windows (bias 0) along the diagonal blocks.
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            # All-valid prefix followed by padding; returns (token flags, pad flags).
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            # Build random THD-style (segment_ids, segment_pos, segment_pad) arrays.
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        # Trailing part of each segment may itself be padding.
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            # Padded positions get segment id 0.
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            # TODO(rewang): record only self attention and find the reason of cross attention
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
                min_segment_len = None if self.window_size is None else self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    # 1D leaves (per-batch seqlens) shard on dp only; 2D leaves
                    # (per-token ids/pos) also shard the sequence axis on cp.
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

zlsh80826's avatar
zlsh80826 committed
686
687
688
689
690
    def test_forward(self):
        """
        Compare the fused-attention forward pass (jitted custom call, sharded
        across self.mesh) against the unjitted pure-JAX reference on identical
        inputs. When self.coll_count_ref is set, also verifies the expected
        communication collectives appear in the compiled HLO.
        """
        self._setup_inputs()

        # Reference implementation consumes the unsharded host-side tensors.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        # Attention configuration shared by both paths; bound as static args below.
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # Specialize jit on the static config and shard the tensor operands;
        # the sharding order here must match customcall_args above.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the context-parallel load-balancing reorder so the output is
            # directly comparable with the reference layout.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        # Skip elementwise comparison when dropout is active in training:
        # the two paths presumably do not share a dropout RNG stream (matches
        # the skip in test_backward) — confirm if this ever changes.
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded positions must be zeroed; valid positions must match the reference.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Optionally validate the communication pattern of the compiled program.
        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

zlsh80826's avatar
zlsh80826 committed
754
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        Compares the fused-attention custom call (jitted, sharded) against the
        pure-JAX reference for the loss value and for dQ/dK/dV (and dBias for
        the [1, h, s, s] bias layout). If coll_count_ref is not None, the
        compiled HLO of the jitted primitive is examined for the expected
        communication collectives.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            """Scalar loss wrapper: masked mean of func's output, amplified."""
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                # Undo the context-parallel load-balancing reorder before masking.
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        # Reference path consumes the unsharded host-side tensors.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        # Attention configuration shared by both paths.
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        # NOTE: plain string labels — these carried a useless f-prefix before
        # (no placeholders); the printed text is unchanged.
        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_grad_valid", reference_out)
        print_debug_tensor_stats("diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            """Assert one gradient tensor matches the reference on valid rows
            and is zero on padded rows."""
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Gradients come back in the CP-reordered layout; restore the
        # reference layout before comparing.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        # Optionally validate the communication pattern of the compiled program.
        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938

@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
939
940
941
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
942
943
944
    ],
)
@pytest.mark.parametrize(
945
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
946
    [
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
        pytest.param(
            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-64-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
        ),
970
971
972
973
974
975
976
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
977
            32,
978
            jnp.bfloat16,
979
980
981
982
            id="2-2048-1024-12-12-64-32-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
983
        ),
984
985
986
987
988
989
990
991
992
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
993
994
995
996
997
998
999
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
1000
1001
1002
1003
1004
1005
1006
1007
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
zlsh80826's avatar
zlsh80826 committed
1008
1009
1010
1011
class TestFusedAttn:
    """
    Fused attention tester
    """
1012

zlsh80826's avatar
zlsh80826 committed
1013
    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Forward-only check across the parameterized configs.

        Not intended to run automatically during CI as it is time-consuming;
        kept for development and debugging.
        """
        # Sliding-window attention uses a left window of s_kv // 10 and no
        # right window; otherwise the window is unrestricted.
        window_size = (s_kv // 10, 0) if swa else None
        FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        ).test_forward()
1076

zlsh80826's avatar
zlsh80826 committed
1077
    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Backward (value_and_grad) check across the parameterized configs;
        always runs with the training path enabled.
        """
        # Sliding-window attention uses a left window of s_kv // 10 and no
        # right window; otherwise the window is unrestricted.
        window_size = (s_kv // 10, 0) if swa else None
        FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # is_training: gradients require the training path
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        ).test_backward()