# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    fused_attn,
    run_length_fill,
    make_swa_mask,
    SequenceDescriptor,
    CPStrategy,
    ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
    get_device_compute_capability,
)

from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialized error
    """
    # Calling customcalls before JAX may cause a CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context
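# Shape walk-through for the GQA path above (illustrative sketch, not executed):
# with h_q = 8 query heads and h_kv = 2 KV heads, num_groups = 4, so a query of
# shape (b, s_q, 8, d) is reshaped to (b, s_q, 2, 4, d); the einsum
# "...qhgd,...khd->...hgqk" then yields logits of shape (b, 2, 4, s_q, s_kv),
# i.e. each KV head is shared by 4 query heads.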


@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create an inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the segment_ids will be applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
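# Example (illustrative): with default positions [0, 1, 2, 3], the inverse causal
# mask row for query position 2 is [True, True, True, False], i.e. a query may
# attend only to keys at positions <= its own.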


@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, with 0s used for the paddings.
      Each segment is expected to start without paddings.
    - segment_pos marks the token position within the segments.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    if attn_mask_type.is_bottom_right():
        run_length_out_q = run_length_fill(segment_ids_q)
        run_length_out_kv = run_length_fill(segment_ids_kv)
        bottom_right_causal_mask = make_attention_mask(
            run_length_out_q - segment_pos_q,
            run_length_out_kv - segment_pos_kv,
            jnp.less_equal,
        )
        inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
    elif attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = (
        make_swa_mask(
            segment_pos_q,
            segment_pos_kv,
            window_size,
            dtype=jnp.bool,
            segment_ids_q=segment_ids_q,
            segment_ids_kv=segment_ids_kv,
        )
        if attn_mask_type.is_bottom_right()
        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
    )
    inv_mask = combine_masks(inv_mask, inv_swa_mask)

    mask = jnp.logical_not(inv_mask)
    return mask
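# Example (illustrative): for segment_ids_q = segment_ids_kv = [[1, 1, 2, 0]] with
# AttnMaskType.PADDING_CAUSAL_MASK and default positions, the second token of
# segment 1 may attend keys {0, 1}, the segment-2 token attends only itself, and
# the padded position is masked out everywhere (True in the returned mask means
# "masked out").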


@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """Return per-segment sequence lengths and segment start offsets for each batch row."""
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        # Mark positions where a new non-zero segment starts
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
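# Example (illustrative): segment_ids = [[1, 1, 2, 2, 2, 0]] gives
#   seqlens ~ [[2, 3, -1, -1, -1, -1]]      (per-segment lengths, -1 for unused slots)
#   offsets ~ [[0, 2, -1, -1, -1, -1, -1]]  (segment start indices, -1 padded)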


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid
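# Example (illustrative): with pad = [[False, False, True]] per token, the "valid"
# outputs zero out the padded rows while the "invalid" outputs keep only the padded
# rows, so the two parts can be checked against zero and against the reference
# independently.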


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
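# Packing sketch (illustrative): for BS3HD, q/k/v of shape (b, s, h, d) are each
# expanded to (b, s, 1, h, d) and concatenated into one (b, s, 3, h, d) tensor;
# BSHD_BS2HD packs only k/v into (b, s, 2, h, d) and passes q separately.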


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"


class SeqDescFormat(Enum):
    """Formats for describing sequences to the fused attention call."""

    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim_qk: int
    head_dim_v: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
    # with generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if self.qkv_layout.is_thd():
            if 90400 <= get_cudnn_version() < 90500:
                return self.num_segments_per_seq
            else:
                # +1 for testing runtime_segments < max_segments
                return self.num_segments_per_seq + 1
        else:
            return 1

    def _check_configs(self):
        # TODO(rewang): probably add this to is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.attn_mask_type.is_bottom_right():
            if self.max_seqlen_q > self.max_seqlen_kv:
                pytest.skip(
                    "BRCM requires a cross-attention pattern, i.e. max_seqlen_kv >= max_seqlen_q"
                )
            if self.attn_bias_type is not AttnBiasType.NO_BIAS:
                pytest.skip("cuDNN does not support pre or post scale bias for BRCM")
            if self.dropout_prob != 0.0:
                pytest.skip("cuDNN does not support non-zero dropout for BRCM")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        if (
            get_device_compute_capability(0) == 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
            )

        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            rng = np.random.default_rng(seed=seed)
            # segment_ids, e.g. [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], where 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # segment_pos, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # segment_pad, e.g. [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], where 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Does not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len forces kv_len >= q_len because cuDNN kernels failed otherwise
                    # TODO(rewang): Remove this constraint once cuDNN supports it
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            # TODO(rewang): only the self-attention case is recorded here; find the reason
            # for the cross-attention failure
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for SWA; otherwise cuDNN kernels don't support it
                min_segment_len = None
                if (
                    self.window_size is not None or self.attn_mask_type.is_bottom_right()
                ):  # SWA or BRCM requires kv_len >= q_len
                    min_segment_len = self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
        else:
            # no-ops for non-CP or non-load-balanced runs
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x
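        # Load-balancing sketch (assumed DualChunkSwap semantics, illustrative only):
        # with cp_size = 2, a causal sequence split into chunks [0, 1, 2, 3] is
        # reordered so rank 0 holds chunks [0, 3] and rank 1 holds [1, 2], evening
        # out the causal-mask work; cp_inverse_reorder_fn undoes this permutation.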

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_descriptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_descriptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_descriptor = None
                    else:
                        self.sequence_descriptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_descriptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_descriptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)

        # Setup shardings for distributed tests
        self.qkvo_pspec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tpsp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_pspec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_descriptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

    def test_forward(self):
        """
        Test forward against the un-JITed JAX reference implementation
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_descriptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        # Skip the elementwise comparison when dropout makes the outputs stochastic
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None, the HLO of the backward function
        will be examined for the expected comms.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only the valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_descriptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout is enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_out", reference_out)
        print_debug_tensor_stats("diff_out", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches share the same actual_seqlen; the tests probably need to be extended
            bias_mask = self.mask[0, 0]

            # Assert that all masked dbias values are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
    [
        pytest.param(
            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
        ),
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            id="2-512-1024-12-12-64-64-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-32-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs.

        This test is not intended to run automatically during CI as it is time-consuming.
        It is kept for development and debugging.
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()