# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax import autocast
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    AttnSoftmaxType,
    QKVLayout,
    QKVFormat,
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
    fused_attn,
    run_length_fill,
    make_swa_mask,
    SequenceDescriptor,
    CPStrategy,
    ReorderStrategy,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
    get_device_compute_capability,
)

from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialization error
    """
    # Calling customcalls before jax may cause a CUDA uninitialization error
    _ = jnp.zeros(0)
    yield


@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    softmax_offset: Optional[ArrayLike],
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    softmax_type: AttnSoftmaxType,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
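    # Illustrative GQA shape walk-through (assumed example, not executed): with
    # h_q=12 and h_kv=6, num_groups=2, so query (b, s_q, 12, d) is viewed as
    # (b, s_q, 6, 2, d) and each KV head serves a group of 2 query heads above.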

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    match softmax_type:
        case AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_out = jax.nn.softmax(logits).astype(dtype)
        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
            # Append a zero logit, apply standard softmax, then remove last column
            zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
            logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # Append learnable offset logit, apply standard softmax, then remove last column
            learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
            learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
            logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case _:
            raise NotImplementedError(f"Unknown {softmax_type=}")
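    # Worked example (illustrative): for logits [0.0, 0.0], vanilla softmax gives
    # [0.5, 0.5], while the off-by-one variant appends a zero logit and yields
    # exp(0) / (exp(0) + exp(0) + exp(0)) = [1/3, 1/3], letting attention mass
    # "leak" to the implicit sink position.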

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
    return context


@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the segment_ids will be applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask


@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start from 1, with 0 used for padding.
      Each segment is expected to start without padding.
    - segment_pos marks the token position within its segment.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    if attn_mask_type.is_bottom_right():
        run_length_out_q = run_length_fill(segment_ids_q)
        run_length_out_kv = run_length_fill(segment_ids_kv)
        bottom_right_causal_mask = make_attention_mask(
            run_length_out_q - segment_pos_q,
            run_length_out_kv - segment_pos_kv,
            jnp.less_equal,
        )
        inv_mask = combine_masks(bottom_right_causal_mask, inv_mask)
    elif attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = (
        make_swa_mask(
            segment_pos_q,
            segment_pos_kv,
            window_size,
            dtype=jnp.bool,
            segment_ids_q=segment_ids_q,
            segment_ids_kv=segment_ids_kv,
        )
        if attn_mask_type.is_bottom_right()
        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
    )
    inv_mask = combine_masks(inv_mask, inv_swa_mask)
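    # Illustrative: with window_size=(2, 0), query position 5 may only attend to
    # key positions 3..5; a window_size of None disables sliding-window masking.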
    mask = jnp.logical_not(inv_mask)
    return mask


@jax.jit
def get_seqlens_and_offsets(segment_ids):
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
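# Example (illustrative): for segment_ids [[1, 1, 2, 2, 2, 0]], the bincount path
# above yields seqlens [[2, 3, -1, -1, -1, -1]] and offsets
# [[0, 2, -1, -1, -1, -1, -1]], with -1 used as the padding sentinel.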


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid
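# Usage note (illustrative): given pad of shape (b, s) marking padded rows, the
# "valid" outputs zero out padded positions and the "invalid" outputs keep only
# them, so valid and padded regions can be asserted on separately.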


def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        softmax_offset,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        softmax_type=kwargs["softmax_type"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    softmax_offset,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(
        qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
    ).astype(query.dtype)
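# Illustrative packing shapes (assumed, not executed): with q/k/v each of shape
# (b, s, h, d), BS3HD stacks them into a single (b, s, 3, h, d) tensor, while
# BSHD_BS2HD keeps q as (b, s, h, d) and packs k/v into (b, s, 2, h, d).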


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"


class SeqDescFormat(Enum):
    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim_qk: int
    head_dim_v: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    softmax_type: AttnSoftmaxType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))
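    # Illustrative mesh sizing (assumed example): number_of_devices=8 with
    # mesh_shape=(2, 2, 2) maps devices onto ("dp", "cp", "tp"), giving
    # dp_size == cp_size == tp_size == 2 in _setup_inputs below.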

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if self.qkv_layout.is_thd():
            if 90400 <= get_cudnn_version() < 90500:
                return self.num_segments_per_seq
            else:
                # +1 for testing runtime_segments < max_segments
                return self.num_segments_per_seq + 1
        else:
            return 1

    def _check_configs(self):
        # TODO(rewang): probably add this to is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.attn_mask_type.is_bottom_right():
            if self.max_seqlen_q > self.max_seqlen_kv:
                pytest.skip(
                    "BRCM requires a cross attention pattern, i.e. max_seqlen_kv >= max_seqlen_q"
                )
            if self.attn_bias_type is not AttnBiasType.NO_BIAS:
                pytest.skip("cuDNN does not support pre or post scale bias for BRCM")
            if self.dropout_prob != 0.0:
                pytest.skip("cuDNN does not support non-zero dropouts for BRCM")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )
        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
        if (
            get_device_compute_capability(0) >= 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
            )
        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.softmax_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            self.softmax_offset = jax.random.uniform(
                softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
            )
        else:
            self.softmax_offset = None

        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)
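        # Example (illustrative): gen_valid(1, 10, 0.3) returns tokens
        # [[1]*7 + [0]*3] and the inverted pad mask [[False]*7 + [True]*3].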

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Does not include padding
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len forces kv_len >= q_len because cuDNN kernels fail otherwise
                    # TODO(rewang): Remove this constraint after cuDNN adds support
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            # TODO(rewang): record only self attention and find the reason for the
            # cross attention failure
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for SWA; otherwise cuDNN kernels are unsupported
                min_segment_len = None
                if (
                    self.window_size is not None or self.attn_mask_type.is_bottom_right()
                ):  # SWA or BRCM requires kv_len >= q_len
                    min_segment_len = self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x
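        # Illustrative: with cp_size=2, DualChunkSwap splits a sequence into
        # chunks [0, 1, 2, 3] and reorders them to [0, 3, 1, 2] so each CP rank
        # receives a balanced share of the causal-attention work.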

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
                            self.cp_reorder_fn(self.segment_pos_q),
                            self.cp_reorder_fn(self.segment_pos_kv),
                        ),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
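        # Standard scaled dot-product attention scaling: e.g. head_dim_qk=64
        # gives scaling_factor = 1/sqrt(64) = 0.125.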

        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tpsp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        # Softmax offset sharding (1, num_heads, 1, 1)
        # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
        head_resource = (
            self.mesh_resource.tpsp_resource
            if self.mesh_resource.tpsp_resource is not None
            else self.mesh_resource.tp_resource
        )
        self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None, the HLO of the backward function will be
        examined for the expected collectives.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small; use a gradient multiplier to amplify it
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only the valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    customcall_fused_dpa,
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    *args,
                    cp_reverse_out=True,
                    **kwargs,
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
                ),
                arg_nums,
            )
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_out", reference_out)
        print_debug_tensor_stats("diff_out", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches share the same actual_seqlen; the tests probably
            # need to be extended
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
        pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
        pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype",
    [
        pytest.param(
            2, 2048, 2048, 12, 12, 64, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-64-BF16-SELF"
        ),
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            id="2-512-1024-12-12-64-64-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 64, jnp.float16, id="4-128-128-16-16-64-64-FP16-SELF"
        ),
        pytest.param(
            4, 128, 128, 16, 16, 64, 32, jnp.float16, id="4-128-128-16-16-64-32-FP16-SELF"
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-32-BF16-CROSS",
        ),
        pytest.param(
            2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is
        time-consuming. It is kept for development and debugging.
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()