test_fused_attn.py 38.3 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
26
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
27
28
29
30
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
31
    QKVFormat,
32
33
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
34
    fused_attn,
35
    fused_attn_thd,
36
    make_swa_mask,
37
    CPStrategy,
38
)
39
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
40
41
42
43
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)
44

45
46
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
47

48

49
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    Workaround (WAR) for a CUDA-uninitialized error.

    Issuing TE custom calls before any regular JAX op may leave CUDA
    uninitialized, so force JAX to initialize it first.
    """
    # A trivial allocation is enough to make JAX bring up CUDA.
    jnp.zeros(0)
    yield


59
# static_argnums: deterministic, scale_factor, dropout_rate, dtype (hashable config).
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    query is (b, s_q, h_q, d); key/value are (b, s_kv, h_kv, d) with
    h_q a multiple of h_kv (grouped-query attention). bias is added to the
    post-scale logits; mask positions that are True are masked OUT.
    Returns the attention context reshaped back to query's shape.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    # Split query heads into (kv_head, group) pairs so each kv head serves a group.
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups so a (b/h, s, s)-shaped bias broadcasts
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        # Insert a group axis so a (b, h, s_q, s_kv) mask broadcasts over groups.
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        # True in the mask means "masked out" -> force logit to dtype minimum.
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        # Inverted dropout: scale kept probabilities by 1/keep_prob.
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    # Weighted sum over the kv sequence, then merge (kv_head, group) back to h_q.
    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context
110
111


112
113
114
115
116
117
118
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: Optional[ArrayLike] = None,
    segment_pos_kv: Optional[ArrayLike] = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, arange of the segment_ids will be applied.
    """
    # Default positions: 0..seqlen-1 per row, broadcast to the batch shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # Allow attention only to kv positions at or before the query position.
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
134

135

136
# static_argnums: attn_mask_type and window_size (hashable config, not arrays).
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: Optional[ArrayLike],
    segment_pos_kv: Optional[ArrayLike],
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, and using 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.

    An example pair of segments_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks: attend only within the same non-padding (id != 0) segment
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    # Default positions: 0..seqlen-1 per row, broadcast to the batch shape.
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    # causal mask
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)

    # Flip the "inverse" (allow) mask into the masked-out convention.
    mask = jnp.logical_not(inv_mask)
    return mask
184

185

186
187
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """
    Derive per-segment lengths and start offsets from a segment-id tensor.

    Returns ``(seqlens, offsets)`` with unused slots set to -1; a sentinel
    entry (0-length / -1-offset) is inserted before the last slot.
    """
    _, max_seqlen = segment_ids.shape
    ids = segment_ids.astype(jnp.int32)

    # Per-row histogram of segment ids; drop bin 0 (the padding id).
    per_row_bincount = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    lengths = per_row_bincount(ids)[..., 1:]

    # A token starts a segment when it differs from its left neighbor and is
    # not padding; the first token starts one iff it is not padding.
    first_starts = segment_ids[..., :1] != 0
    later_starts = jnp.logical_and(
        segment_ids[..., 1:] != segment_ids[..., :-1], segment_ids[..., 1:] != 0
    )
    is_start = jnp.hstack((first_starts, later_starts))
    offsets = jax.vmap(partial(jnp.argwhere, size=max_seqlen, fill_value=-1))(
        is_start
    ).squeeze(-1)

    # Insert the sentinel slot before the last entry of each row.
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
    lengths = jnp.insert(lengths, -1, values=0, axis=-1)
    # Zero-length (unused) slots are reported as -1.
    lengths = jnp.where(lengths, lengths, -1)
    return lengths, offsets
206
207
208
209
210
211
212
213
214
215
216
217
218


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications (split valid vs padded slots)"""
    # Broadcast the (batch, seqlen) padding flags over head/feature dims.
    pad_mask = pad[..., jnp.newaxis, jnp.newaxis]

    def _split(tensor):
        # (valid part, padded part) — the complementary slots are zero-filled.
        return jnp.where(pad_mask, 0, tensor), jnp.where(pad_mask, tensor, 0)

    primitive_valid, primitive_invalid = _split(primitive)
    reference_valid, reference_invalid = _split(reference)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation (reference path).
    """
    # Pull the relevant config out of kwargs up front for readability.
    deterministic = not kwargs["is_training"]
    scale = kwargs["scaling_factor"]
    dropout = kwargs["dropout_probability"]

    # Run the reference in float32 for accuracy, then cast back to the
    # input dtype so outputs are comparable with the fused kernel.
    out_f32 = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=deterministic,
        scale_factor=scale,
        dropout_rate=dropout,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return out_f32.astype(query.dtype)
235
236


237
238
239
240
241
242
243
244
245
246
247
248
249
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation.

    Packs q/k/v according to the requested qkv_layout, then dispatches to
    fused_attn (dense layouts) or fused_attn_thd (ragged THD layouts).
    """
    qkv_layout = kwargs["qkv_layout"]
    add_pack_axis = partial(jnp.expand_dims, axis=-3)

    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # qkv-packed: interleave q, k, v along a new axis before the heads.
        packed_qkv = jnp.concatenate(
            (add_pack_axis(query), add_pack_axis(key), add_pack_axis(value)), axis=-3
        )
        qkv_args = (packed_qkv,)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # kv-packed: q stays separate, k and v are interleaved.
        packed_kv = jnp.concatenate((add_pack_axis(key), add_pack_axis(value)), axis=-3)
        qkv_args = (query, packed_kv)
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        # separate: pass q, k, v through unchanged.
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")

    if not qkv_layout.is_thd():
        # max_segments_per_seq only applies to the ragged (THD) entry point.
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)

    return fused_attn_thd(
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
    ).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
280
281


282
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.

    Each value names the broadcast layout of the bias tensor as
    (batch, heads, seqlen_q, seqlen_kv); a '1' marks a broadcast dimension.
    """

    _1HSS = "1HSS"  # (1, heads, s_q, s_kv)
    _B1SS = "B1SS"  # (batch, 1, s_q, s_kv)
    _BHSS = "BHSS"  # (batch, heads, s_q, s_kv)
    _11SS = "11SS"  # (1, 1, s_q, s_kv)
291
292


zlsh80826's avatar
zlsh80826 committed
293
294
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner

    Holds one test configuration; builds inputs and shardings, then compares
    the TE fused-attention custom call against the JAX reference.
    """

    # Core attention problem configuration
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int  # may differ from num_heads_q for GQA/MQA layouts
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    # (left, right) sliding window; None disables sliding window attention
    window_size: Optional[Tuple[int, int]] = None

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

327
328
329
330
331
332
333
334
335
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        """Return the max segment count to pass to the kernel for this run."""
        affected_cudnn = 90400 <= get_cudnn_version() < 90500
        if affected_cudnn:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1

zlsh80826's avatar
zlsh80826 committed
336
    def _check_configs(self):
        """
        Validate this configuration and resolve the fused-attention backend,
        skipping the test for combinations the kernels do not support.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # qkv-packed layouts share one tensor, so q/kv dims must match.
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        # Ask TE which fused-attention backend (if any) supports this config.
        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS post-scale bias shapes are only supported in a narrow regime.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
382

zlsh80826's avatar
zlsh80826 committed
383
384
    def _setup_inputs(self):
        """
        Build all test inputs: q/k/v (and optional bias), padding/segment
        metadata, reference and customcall masks, and the distributed
        sharding specs used by the forward/backward tests.
        """
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        # Tensors are laid out (batch, seqlen, heads, head_dim).
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )

        # Map the BiasShape enum to a concrete broadcastable tensor shape.
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Open rectangular "attend" windows between consecutive ids.
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            # Valid tokens first, trailing padding; returns (tokens, pad flags).
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            # Build random THD-style (segment_ids, segment_pos, segment_pad).
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constraint after cuDNN supports it
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        # Randomly pad out the tail of the segment.
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                # Everything past the last generated segment is padding.
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            # Padded positions get segment id 0.
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            if self.qkv_layout == QKVLayout.T3HD:
                # qkv-packed THD shares one set of segment metadata.
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
                min_segment_len = None if self.window_size is None else self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.qkv_layout.is_thd():
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            # Customcall mask omits the window_size (handled by the kernel).
            self.mask_for_customcall = make_mask(
                self.segment_ids_q,
                self.segment_ids_kv,
                self.segment_pos_q,
                self.segment_pos_kv,
                self.attn_mask_type,
            )

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        self.mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)

        # Bias sharding follows the broadcast layout of the chosen bias shape.
        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

        # Softmax aux sharding

        if self.cp_size > 1 and self.cp_load_balanced:
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                cp_size=self.cp_size,
                tensor_format=self.qkv_layout.get_qkv_format(),
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                cp_size=self.cp_size,
                tensor_format=self.qkv_layout.get_qkv_format(),
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

zlsh80826's avatar
zlsh80826 committed
618
619
620
621
622
    def test_forward(self):
        """
        Run the TE fused-attention custom call (jitted, with shardings) and
        compare its output against the JAX reference implementation.
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.mask_for_customcall, self.mask_sharding),
            jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
            jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
            jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
            jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # Bind config as static kwargs and attach input shardings.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.mask_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the CP load-balancing reorder before comparing.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        # With dropout active the two implementations draw different masks,
        # so elementwise comparison is meaningless.
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded positions must be exactly zero; valid ones must match reference.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Optionally verify the compiled HLO contains the expected collectives.
        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

zlsh80826's avatar
zlsh80826 committed
694
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward.

        Compares the fused-attention custom call (``customcall_fused_dpa``)
        against the pure-JAX reference (``jax_dpa``) on output value and on
        dq/dk/dv (and dbias for the 1HSS layout).

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected comms.
        """

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Reduce the attention output to a scalar loss so value_and_grad
            # can differentiate it.
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        # Reference path takes the raw inputs; the custom-call path takes
        # device_put inputs (with CP reordering applied to q/k/v) so each
        # operand lands on its expected sharding.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.mask_for_customcall, self.mask_sharding),
            jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
            jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
            jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
            jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.mask_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.seq_length_offset_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats(f"primitive_out", primitive_out)
        print_debug_tensor_stats(f"reference_grad_valid", reference_out)
        print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            # Compare one gradient tensor: padded (invalid) region must be
            # exactly zero and must match the reference; valid region must
            # match within dtype tolerance.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Undo the context-parallel load-balancing reorder so the gradients
        # line up with the (un-reordered) reference gradients.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        # Optionally verify the compiled HLO contains exactly the expected
        # collective ops (context-parallel runs set coll_count_ref).
        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    # b: batch, s_q/s_kv: query/key-value seqlen, h_q/h_kv: query/key-value
    # head counts (h_q != h_kv exercises GQA), d: head dim.
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    # swa: sliding-window attention on/off.
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester.

    Parameterized entry points that build a FusedAttnRunner for each
    mask/layout/shape/dtype/dropout/SWA combination and delegate the actual
    forward/backward checks to it.
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test forward with parameterized configs.

        Leading underscore keeps pytest from collecting it:
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        window_size = None
        if swa:
            # Left window covers ~10% of the KV sequence; right window 0
            # (attend to past only).
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        # Backward supports dBias only for the 1HSS bias layout, so fewer
        # bias shapes are tested than in the forward case.
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            # Left window covers ~10% of the KV sequence; right window 0
            # (attend to past only).
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # backward requires is_training=True
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_backward()