test_fused_attn.py 51 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
from transformer_engine.jax import autocast
26
from transformer_engine.jax.sharding import MeshResource
27
28
29
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
30
    AttnSoftmaxType,
31
    QKVLayout,
Reese Wang's avatar
Reese Wang committed
32
    QKVFormat,
33
34
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
35
    fused_attn,
36
    run_length_fill,
37
    make_swa_mask,
38
    SequenceDescriptor,
39
    CPStrategy,
Reese Wang's avatar
Reese Wang committed
40
    ReorderStrategy,
41
)
42
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
43
from transformer_engine_jax import (
44
45
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
46
    get_device_compute_capability,
47
)
48

49
50
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
51

52

53
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error

    Module-scoped autouse fixture: runs once before any test in this module.
    """
    # Calling customcalls before jax may cause CUDA uninitialize error, so touch
    # JAX first to force CUDA/XLA initialization before any TE custom call runs.
    _ = jnp.zeros(0)
    yield


63
@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
64
65
66
67
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
68
    softmax_offset: Optional[ArrayLike],
69
70
71
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
72
    softmax_type: AttnSoftmaxType,
73
74
75
76
77
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
78
    """
zlsh80826's avatar
zlsh80826 committed
79
    Similar to flax.linen.dot_product_attention but with GQA support
80
    """
zlsh80826's avatar
zlsh80826 committed
81
82
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
83

zlsh80826's avatar
zlsh80826 committed
84
    b, s_q, h_q, d = query.shape
85
    _, s_kv, h_kv, _ = key.shape
zlsh80826's avatar
zlsh80826 committed
86
87
88
89
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
90
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
zlsh80826's avatar
zlsh80826 committed
91
92

    if bias is not None:
93
94
95
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
zlsh80826's avatar
zlsh80826 committed
96
        logits = logits + bias
97
98
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))
zlsh80826's avatar
zlsh80826 committed
99
100
101
102

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
103
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
104

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
    match softmax_type:
        case AttnSoftmaxType.VANILLA_SOFTMAX:
            softmax_out = jax.nn.softmax(logits).astype(dtype)
        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
            # Append a zero logit, apply standard softmax, then remove last column
            zero_logit = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
            logits_with_extra = jnp.concatenate([logits, zero_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
            # Append learnable offset logit, apply standard softmax, then remove last column
            learnable_logit = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
            learnable_logit = jnp.broadcast_to(learnable_logit, logits.shape[:-1] + (1,))
            logits_with_extra = jnp.concatenate([logits, learnable_logit], axis=-1)
            softmax_with_extra = jax.nn.softmax(logits_with_extra, axis=-1)
            softmax_out = softmax_with_extra[..., :-1].astype(dtype)
        case _:
            raise NotImplementedError(f"Unknown {softmax_type=}")
124

125
    if not deterministic and dropout_rate > 0.0:
zlsh80826's avatar
zlsh80826 committed
126
127
128
129
130
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

131
    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
132
133
    context_shape = query.shape[:-1] + (value.shape[-1],)
    context = jnp.reshape(context, context_shape)
zlsh80826's avatar
zlsh80826 committed
134
    return context
135
136


137
138
139
140
141
142
143
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Build an inverse padded causal mask: `True` marks pairs of positions that
    are allowed to attend to each other, `False` marks pairs that are masked
    out. When a segment_pos tensor is omitted, an arange over the matching
    segment_ids is substituted for it.
    """

    def _default_positions(ids):
        # Position index 0..len-1, broadcast to the segment_ids shape.
        iota = jnp.arange(ids.shape[-1], dtype=jnp.int32)
        return jnp.broadcast_to(iota, ids.shape)

    if segment_pos_q is None:
        segment_pos_q = _default_positions(segment_ids_q)
    if segment_pos_kv is None:
        segment_pos_kv = _default_positions(segment_ids_kv)

    # Causal: a query position may attend to any kv position at or before it.
    return make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
159

160

161
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Build the attention mask for the requested mask type. In the returned
    mask, `True` masks a position out and `False` lets it attend.

    Conventions:
    - segment_ids label tokens starting from 1; 0 marks padding. Each segment
      is expected to begin without padding.
    - segment_pos is the token position within its segment.

    Example segment_ids / segment_pos pair:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # Tokens may attend only within the same non-padding segment.
    keep_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    # Default positions: arange over the sequence axis.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    if attn_mask_type.is_bottom_right():
        # Bottom-right causal: align each segment's diagonal to its end by
        # comparing distance-to-segment-end on both sides.
        fill_q = run_length_fill(segment_ids_q)
        fill_kv = run_length_fill(segment_ids_kv)
        brcm_keep = make_attention_mask(
            fill_q - segment_pos_q,
            fill_kv - segment_pos_kv,
            jnp.less_equal,
        )
        keep_mask = combine_masks(brcm_keep, keep_mask)
    elif attn_mask_type.is_causal():
        # Top-left causal: query position >= kv position.
        causal_keep = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
        keep_mask = combine_masks(causal_keep, keep_mask)

    # Sliding-window restriction (no-op when window_size is None).
    if attn_mask_type.is_bottom_right():
        swa_keep = make_swa_mask(
            segment_pos_q,
            segment_pos_kv,
            window_size,
            dtype=jnp.bool,
            segment_ids_q=segment_ids_q,
            segment_ids_kv=segment_ids_kv,
        )
    else:
        swa_keep = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_)
    keep_mask = combine_masks(keep_mask, swa_keep)

    # Flip polarity: callers expect True == masked out.
    return jnp.logical_not(keep_mask)
228

229

230
231
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """Compute per-segment lengths and start offsets from a segment-id tensor.

    Args:
        segment_ids: integer array of shape (batch, max_seqlen) where 0 marks
            padding and valid segments are labeled 1..N in order of appearance.

    Returns:
        A tuple ``(seqlens, offsets)``:
        - seqlens: (batch, max_seqlen) token count of each segment; unused
          trailing slots (including an appended sentinel slot) are set to -1.
        - offsets: (batch, max_seqlen + 1) start index of each segment, with
          unused slots (and the appended sentinel) filled with -1.
    """
    _, max_seqlen = segment_ids.shape
    # Per-row histogram of segment ids; drop the id-0 (padding) counts.
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        # A segment starts where the id changes to a non-zero value ...
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        # ... or at column 0 when it is not padding.
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        # Fixed-size argwhere keeps the shape static for jit; pads with -1.
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    # Append a sentinel column for runtime_segments < max_segments cases.
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    # Replace empty segment lengths with the -1 sentinel.
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
250
251
252
253
254
255
256
257
258
259
260
261


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


262
def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
    """
    Reference attention path built on the pure-JAX implementation. The math
    runs in float32 and the result is cast back to the query dtype.
    """
    ctx = general_dot_product_attention(
        query,
        key,
        value,
        softmax_offset,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        softmax_type=kwargs["softmax_type"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return ctx.astype(query.dtype)
281
282


283
284
285
286
287
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    softmax_offset,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    Run attention through the TE custom-call (fused_attn), packing Q/K/V into
    whatever layout kwargs["qkv_layout"] requests before dispatch.
    """
    qkv_layout = kwargs["qkv_layout"]
    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # Fully packed: stack Q, K, V along a new axis -> single qkv tensor.
        expanded = [jnp.expand_dims(t, axis=-3) for t in (query, key, value)]
        qkv_args = (jnp.concatenate(expanded, axis=-3),)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # KV packed: stack K and V, keep Q separate.
        expanded_kv = [jnp.expand_dims(t, axis=-3) for t in (key, value)]
        qkv_args = (query, jnp.concatenate(expanded_kv, axis=-3))
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        # Fully separate tensors.
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")

    out = fused_attn(
        qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs
    )
    return out.astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
313
314


315
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    # Names encode the broadcast pattern over (batch, heads, seqlen_q, seqlen_kv):
    # B = full batch dim, H = full head dim, S = full seqlen dim, 1 = broadcast axis.
    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"
324
325


326
327
328
329
330
331
class SeqDescFormat(Enum):
    """Input formats a test can use to describe sequences to fused_attn."""

    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


zlsh80826's avatar
zlsh80826 committed
332
333
@dataclass
class FusedAttnRunner:
334
    """
zlsh80826's avatar
zlsh80826 committed
335
    Fused attention runner
336
    """
337

zlsh80826's avatar
zlsh80826 committed
338
339
340
341
342
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
343
344
    head_dim_qk: int
    head_dim_v: int
zlsh80826's avatar
zlsh80826 committed
345
346
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
347
    softmax_type: AttnSoftmaxType
zlsh80826's avatar
zlsh80826 committed
348
349
350
351
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
352
    bias_shape: BiasShape
353
354
    window_size: Tuple[int, int]
    seq_desc_format: SeqDescFormat
355
356
    stripe_size: int | None = None
    num_segments_per_seq: int | None = None
zlsh80826's avatar
zlsh80826 committed
357

358
359
360
361
362
363
364
365
366
367
368
369
370
    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

371
372
373
374
375
376
377
378
    def __post_init__(self):
        # Default the segment count when the caller did not pass one explicitly:
        # THD (ragged) layouts exercise multiple segments per sequence.
        if self.num_segments_per_seq is None:
            self.num_segments_per_seq = 2 if self.qkv_layout.is_thd() else 1

379
380
381
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        """Maximum segment count advertised to the kernel; 1 for non-THD layouts."""
        if not self.qkv_layout.is_thd():
            return 1
        # cuDNN 9.4.x has the zero-length ragged-tensor issue noted above, so
        # do not over-allocate segments there.
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1
390

zlsh80826's avatar
zlsh80826 committed
391
    def _check_configs(self):
        """
        Skip test configurations that the fused-attention backend cannot run.

        Side effect: stores the selected backend in ``self.backend``; every
        other branch either passes through or calls ``pytest.skip``.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # Bottom-right causal masking (BRCM) restrictions.
        if self.attn_mask_type.is_bottom_right():
            if self.max_seqlen_q > self.max_seqlen_kv:
                pytest.skip(
                    f"BRCM requires cross attn type pattern, i.e.max_seqlen_kv >= max_seqlen_q"
                )
            if self.attn_bias_type is not AttnBiasType.NO_BIAS:
                pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM")
            if self.dropout_prob != 0.0:
                pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM")

        # Packed QKV layouts require self-attention-shaped inputs.
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )
        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
        if (
            get_device_compute_capability(0) >= 100
            and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
        ):
            pytest.skip(
                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
            )

        # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
        # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
        if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate():
            pytest.skip(
                "For head_dim_qk != head_dim_v, it is necessary that the QKV layout "
                "is either BSHD_BSHD_BSHD or THD_THD_THD"
            )

        # Ask TE which backend would serve this configuration.
        self.backend = FusedAttnHelper(
            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.softmax_type,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim_qk,
            self.head_dim_v,
            # (-1, -1) means "no sliding window" to the backend helper.
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS bias shapes are only supported on a narrow set of configs.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
466

zlsh80826's avatar
zlsh80826 committed
467
468
    def _setup_inputs(self):
        self._check_configs()
469
470
471
472
473
474

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
475
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
476

zlsh80826's avatar
zlsh80826 committed
477
        key = jax.random.PRNGKey(0)
478
        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
479

480
481
482
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
        k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
        v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_v)
483

484
485
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
486
        elif self.bias_shape == BiasShape._1HSS:
487
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
488
        elif self.bias_shape == BiasShape._B1SS:
489
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
490
        elif self.bias_shape == BiasShape._BHSS:
491
492
493
494
495
496
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
497
        elif self.bias_shape == BiasShape._11SS:
498
499
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
500
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
501

502
503
504
        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
505
506

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
507
            if self.bias_shape == BiasShape._1HSS:
508
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
509
510
511
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
512
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
513
514
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
515
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
516
517
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
518
519
520
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
521
522
        else:
            self.bias = None
523

524
        if self.attn_mask_type.is_padding():
525
            pad_ratio = 0.3
526
527
        else:
            pad_ratio = 0.0
528

529
530
531
532
533
534
535
        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
            self.softmax_offset = jax.random.uniform(
                softmax_key, (1, self.num_heads_q, 1, 1), jnp.float32, -1.0
            )
        else:
            self.softmax_offset = None

zlsh80826's avatar
zlsh80826 committed
536
537
538
539
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
540
541
542
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
543
544
545
546
547
548
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
549
550
551
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
552
553
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
554
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
555
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
556
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)
557
558
559
560
561
562
563

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

564
565
566
567
568
569
570
                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
571
572
573
574
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
575
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
576
                    if with_segment_pad:
577
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
578
579
580
581
582
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

583
584
585
586
587
588
589
590
            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
591
592
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
593
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
Reese Wang's avatar
Reese Wang committed
594
595
            # TODO(rewang): record only self attention and find the reason of cross attention
            if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
596
597
598
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
599
            else:
600
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
601
602
603
604
605
                min_segment_len = None
                if (
                    self.window_size is not None or self.attn_mask_type.is_bottom_right()
                ):  # SWA or BRCM requires kv_len >= q_len
                    min_segment_len = self.seqlens_q
606
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
607
608
609
610
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
611
                    min_segment_len=min_segment_len,
612
                )
613
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
614
        else:
615
616
617
618
619
620
621
622
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
623

624
        # For reference code
625
        self.mask = make_mask(
626
627
628
629
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
630
            self.attn_mask_type,
631
            self.window_size,
632
        )
633

Reese Wang's avatar
Reese Wang committed
634
635
636
637
638
639
640
641
642
643
644
645
        if self.cp_size > 1 and self.cp_load_balanced:
            if self.qkv_layout.is_thd():
                reorder_strategy = ReorderStrategy.Striped
            else:
                reorder_strategy = ReorderStrategy.DualChunkSwap

            seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
646
                stripe_size=self.stripe_size,
Reese Wang's avatar
Reese Wang committed
647
648
649
650
651
652
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                strategy=reorder_strategy,
                cp_size=self.cp_size,
                seq_dim=seq_dim,
653
                stripe_size=self.stripe_size,
Reese Wang's avatar
Reese Wang committed
654
655
656
657
658
659
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

660
        # Test different input formats
661
        if self.qkv_layout.is_thd():
662
663
664
665
666
667
668
669
670
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
671
672
                    # Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
                    # if no CP and load balancing, else explicitly pass the segment_pos
673
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
Reese Wang's avatar
Reese Wang committed
674
675
676
677
678
                        (
                            self.cp_reorder_fn(self.segment_ids_q),
                            self.cp_reorder_fn(self.segment_ids_kv),
                        ),
                        (
679
680
681
682
683
684
685
686
687
688
                            (
                                self.cp_reorder_fn(self.segment_pos_q),
                                self.cp_reorder_fn(self.segment_pos_kv),
                            )
                            if self.cp_size > 1 and self.cp_load_balanced
                            else None
                        ),
                        is_thd=self.qkv_layout.is_thd(),
                        is_segment_ids_reordered=(
                            True if self.cp_size > 1 and self.cp_load_balanced else False
Reese Wang's avatar
Reese Wang committed
689
                        ),
690
691
692
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
693
        else:
694
695
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
696
697
698
699
700
701
702
703
704
705
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
706
707
708
709
710
711
712
713
714
715
716
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
717
718
                        is_thd=self.qkv_layout.is_thd(),
                        is_segment_ids_reordered=False,
719
720
721
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
722

zlsh80826's avatar
zlsh80826 committed
723
        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
724
        self.scaling_factor = 1.0 / sqrt(self.head_dim_qk)
725

726
727
728
729
730
        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
731
            self.mesh_resource.tpsp_resource,
732
733
734
735
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

736
        mask_pspec = PartitionSpec(
737
738
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
739
740
741
742
743
744
745
746
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
Reese Wang's avatar
Reese Wang committed
747
748
749
750
751
752
                    if x.ndim == 1:
                        pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    else:
                        pspec = PartitionSpec(
                            self.mesh_resource.dp_resource, self.mesh_resource.cp_resource
                        )
753
754
755
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)
756
757
758

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
759
                None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None
760
761
762
763
764
765
766
767
768
769
770
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

771
772
773
774
775
776
777
778
779
780
        # Softmax offset sharding (1, num_heads, 1, 1)
        # Use the same logic as HEAD_AXES: tpsp_resource if enabled, else tp_resource
        head_resource = (
            self.mesh_resource.tpsp_resource
            if self.mesh_resource.tpsp_resource is not None
            else self.mesh_resource.tp_resource
        )
        self.softmax_offset_pspec = PartitionSpec(None, head_resource, None, None)
        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)

781
782
783
784
785
786
787
788
789
790
791
792
793
        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

    def test_forward(self):
        """
        Test forward with JITted primitive and unJITted reference.

        Runs the fused-attention custom call (JITted, with distributed
        shardings applied) and the pure-JAX reference implementation on the
        same inputs, then compares valid (non-padded) outputs elementwise.
        When dropout is active during training the elementwise comparison is
        skipped since the two implementations use different RNG streams.
        If ``self.coll_count_ref`` is set, the compiled HLO is additionally
        checked for the expected collective operations.
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the CP load-balancing reorder so outputs line up with the reference.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        # Dropout RNG streams differ between the primitive and the reference,
        # so an elementwise comparison is meaningless in that case.
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded positions must be exactly zero; valid positions must match the reference.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected comms.
        """

        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            """Reduce the attention output to a scalar loss over valid positions only."""
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                # Undo the CP load-balancing reorder before masking padded rows.
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "softmax_type": self.softmax_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
            "stripe_size": self.stripe_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    customcall_fused_dpa,
                    q,
                    k,
                    v,
                    bias,
                    softmax_offset,
                    *args,
                    cp_reverse_out=True,
                    **kwargs,
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.softmax_offset_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, softmax_offset, *args: grad_func(
                    jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs
                ),
                arg_nums,
            )
        )

        with self.mesh, autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats("primitive_out", primitive_out)
        print_debug_tensor_stats("reference_grad_valid", reference_out)
        print_debug_tensor_stats("diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            """Compare one of dQ/dK/dV against the reference, split by valid/padded rows."""
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Gradients from the primitive are still in CP load-balanced order.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.

            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        if self.coll_count_ref is not None:
            with self.mesh, autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

1049
1050
1051
1052
1053
1054
1055
1056

@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
        pytest.param(
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
        ),
    ],
)
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
        pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"),
        pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
    [
        # large data size + bf16 + qkv packed
        pytest.param(
            2,
            2048,
            2048,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BS3HD,
            id="2-2048-2048-12-12-64-64-BF16-SELF-QKV_PACKED",
        ),
        pytest.param(
            2,
            2048,
            2048,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.T3HD,
            id="2-2048-2048-12-12-64-64-BF16-SELF-RAGGED_QKV_PACKED",
        ),
        # mid data size + bf16 + cross attn + kv packed
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BSHD_BS2HD,
            id="2-512-1024-12-12-64-64-BF16-CROSS-KV_PACKED",
        ),
        pytest.param(
            2,
            512,
            1024,
            12,
            12,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.THD_T2HD,
            id="2-512-1024-12-12-64-64-BF16-CROSS-RAGGED_KV_PACKED",
        ),
        # large data size + bf16 + cross attn + diff hidden v dim + qkv separate
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            QKVLayout.BSHD_BSHD_BSHD,
            id="2-2048-1024-12-12-64-32-BF16-CROSS-SEPARATE",
        ),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            32,
            jnp.bfloat16,
            QKVLayout.THD_THD_THD,
            id="2-2048-1024-12-12-64-32-BF16-CROSS-RAGGED_SEPARATE",
        ),
        # large data size + bf16 + gqa + kv packed
        pytest.param(
            2,
            2048,
            2048,
            12,
            6,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.BSHD_BS2HD,
            id="2-2048-2048-12-6-64-64-BF16-GQA-KV_PACKED",
        ),
        pytest.param(
            2,
            2048,
            2048,
            12,
            6,
            64,
            64,
            jnp.bfloat16,
            QKVLayout.THD_T2HD,
            id="2-2048-2048-12-6-64-64-BF16-GQA-RAGGED_KV_PACKED",
        ),
        # small data size + fp16 + diff hidden v dim + qkv packed
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            32,
            jnp.float16,
            QKVLayout.BS3HD,
            id="4-128-128-16-16-64-32-FP16-SELF-QKV_PACKED",
        ),
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            32,
            jnp.float16,
            QKVLayout.T3HD,
            id="4-128-128-16-16-64-32-FP16-SELF-RAGGED_QKV_PACKED",
        ),
        # small data size + fp16 + kv packed
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            64,
            jnp.float16,
            QKVLayout.BSHD_BS2HD,
            id="4-128-128-16-16-64-64-FP16-SELF-KV_PACKED",
        ),
        pytest.param(
            4,
            128,
            128,
            16,
            16,
            64,
            64,
            jnp.float16,
            QKVLayout.THD_T2HD,
            id="4-128-128-16-16-64-64-FP16-SELF-RAGGED_KV_PACKED",
        ),
        # large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.float16,
            QKVLayout.BSHD_BSHD_BSHD,
            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-SEPARATE",
        ),
        pytest.param(
            2,
            1024,
            2048,
            12,
            6,
            128,
            64,
            jnp.float16,
            QKVLayout.THD_THD_THD,
            id="2-1024-2048-12-6-128-64-FP16-CROSS-GQA-RAGGED_SEPARATE",
        ),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d_qk,
        d_v,
        attn_bias_type,
        attn_mask_type,
        softmax_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        # Backward always runs in training mode (is_training=True).
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d_qk,
            d_v,
            attn_bias_type,
            attn_mask_type,
            softmax_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()