test_fused_attn.py 39.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum, auto
6
from dataclasses import dataclass, field
zlsh80826's avatar
zlsh80826 committed
7
from functools import partial
8
from math import sqrt
9
from typing import Tuple, Optional, Dict
10
import random
11
12
13

import jax
import jax.numpy as jnp
14
import numpy as np
15
16
17
18
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
19
20
from flax.linen.dtypes import promote_dtype
from jax import Array
21
from jax import value_and_grad, jit
22
from jax.sharding import Mesh, NamedSharding, PartitionSpec
zlsh80826's avatar
zlsh80826 committed
23
from jax.typing import ArrayLike, DTypeLike
24

25
26
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource
27
28
29
30
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
31
32
    reorder_causal_load_balancing,
    inverse_reorder_causal_load_balancing,
33
    fused_attn,
34
    make_swa_mask,
35
    SequenceDescriptor,
36
    CPStrategy,
37
)
38
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
39
40
41
42
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)
43

44
45
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
46

47

48
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialize error
    """
    # Calling customcalls before jax may cause CUDA uninitialize error:
    # allocating a dummy array first forces JAX/XLA to initialize the device.
    _ = jnp.zeros(0)
    yield


58
@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support.

    Args:
        query: [batch, seqlen_q, num_heads_q, head_dim].
        key, value: [batch, seqlen_kv, num_heads_kv, head_dim]; num_heads_q must
            be a positive multiple of num_heads_kv (grouped-query attention).
        bias: optional post-scale bias added to the logits (broadcastable to
            [batch, num_heads_q, seqlen_q, seqlen_kv]), or None.
        mask: optional boolean mask where `True` masks OUT a position, or None.
        deterministic: if True, dropout is disabled.
        scale_factor: multiplier applied to the QK^T logits.
        dropout_rate: attention dropout probability.
        dropout_rng: PRNG key used when dropout is active.
        dtype: dtype the inputs are promoted to before the computation.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape

    # GQA: fold query heads into (kv_head, group) pairs so each kv head serves
    # num_groups query heads.
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        # `True` in the mask means the position is masked out.
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        # Inverted dropout: scale the kept probabilities so the expectation is
        # unchanged.
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context
109
110


111
112
113
114
115
116
117
@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the segment_ids will be applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    # A query position may attend to any kv position at or before it.
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
133

134

135
@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, and using 0s for the paddings.
      Expected that each segment starts without paddings.
    - segment_pos marks the token position in the segments.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks: tokens may only attend within the same non-padding segment
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    # causal mask
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask (make_swa_mask is a no-op when window_size is None)
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)

    # Flip polarity: the helpers build "allow" masks, the caller wants "mask out".
    mask = jnp.logical_not(inv_mask)
    return mask
183

184

185
186
@jax.jit
def get_seqlens_and_offsets(segment_ids):
    """
    Derive per-segment sequence lengths and segment start offsets from a batch
    of segment-id rows.

    Args:
        segment_ids: [batch, max_seqlen] integer tensor where 0 marks padding
            and positive values identify segments.

    Returns:
        seqlens: [batch, max_seqlen] tensor; column i-1 holds the length of
            segment id i, unused slots (including an appended sentinel column)
            are -1.
        offsets: [batch, max_seqlen + 1] tensor of segment start indices;
            unused slots and the appended sentinel column are -1.
    """
    _, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    # Drop the count of id 0 (padding); segment id i's length lands at column i-1.
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        # A segment starts where the id is non-zero and differs from the
        # previous token (the first column starts a segment iff it is non-zero).
        is_segment_start = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        is_segment_start = jnp.hstack((first_column, is_segment_start))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            is_segment_start
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    # Append sentinel columns and mark empty length slots with -1.
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
205
206
207
208
209
210
211
212
213
214
215
216
217


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    # The reference path computes in float32 and casts back to the input dtype.
    deterministic = not kwargs["is_training"]
    ref_out = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=deterministic,
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return ref_out.astype(query.dtype)
234
235


236
237
238
239
240
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    sequence_descriptor,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        # Fully packed layout: stack q, k, v along a new interleave axis.
        query, key, value = [jnp.expand_dims(t, axis=-3) for t in (query, key, value)]
        qkv = jnp.concatenate((query, key, value), axis=-3)
        qkv_args = (qkv,)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        # KV-packed layout: q stays separate, k and v are stacked together.
        key, value = [jnp.expand_dims(t, axis=-3) for t in (key, value)]
        kv = jnp.concatenate((key, value), axis=-3)
        qkv_args = (query, kv)
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        # Separate layout: pass the three tensors through unchanged.
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")
    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
        query.dtype
    )
zlsh80826's avatar
zlsh80826 committed
265
266


267
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    # Shape letters: B=batch, H=heads, S=sequence; a literal 1 is a broadcast dim.
    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"
276
277


278
279
280
281
282
283
class SeqDescFormat(Enum):
    """Input formats a test can use to describe the sequences to fused attention."""

    Mask = auto()
    Seqlens = auto()
    SegmentIDs = auto()


zlsh80826's avatar
zlsh80826 committed
284
285
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    # Problem geometry
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    # Attention configuration
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    # (left, right) sliding window; None disables windowing
    window_size: Optional[Tuple[int, int]]
    # How the sequence information is fed to fused_attn (mask/seqlens/segment ids)
    seq_desc_format: SeqDescFormat

    # Specifies sharding resources for distributed tests
    number_of_devices: int = 1
    mesh_shape: tuple[int, ...] = (1, 1, 1)
    mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))

    # Context parallel aux arguments
    cp_strategy: CPStrategy = CPStrategy.DEFAULT
    cp_load_balanced: bool = True

    # dictionary of expected collective comm bytes
    coll_count_ref: Optional[Dict[str, int]] = None

319
320
321
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        """Return the max_segments_per_seq to pass to the kernel for this layout."""
        # Non-THD layouts always carry exactly one segment per sequence.
        if not self.qkv_layout.is_thd():
            return 1
        # cuDNN 9.4.x cannot take the extra (possibly zero-length) segment slot.
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1
330

zlsh80826's avatar
zlsh80826 committed
331
    def _check_configs(self):
        """Skip the test early for configurations the fused-attention path cannot run."""
        # TODO(rewang): probably adds this in is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        # Packed QKV requires q and kv to share shape parameters.
        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        # Query the backend cuDNN would pick for this configuration; skip if none.
        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Non-1HSS post-scale bias shapes are only supported in a narrow regime.
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
377

zlsh80826's avatar
zlsh80826 committed
378
379
    def _setup_inputs(self):
        """Build the tensors, masks, sequence descriptors and shardings for one case."""
        self._check_configs()

        # Create a mesh for distributed tests
        self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
        self.mesh = Mesh(self.devices, self.mesh_axes)
        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)

        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )

        # Resolve the bias tensor shape from the configured BiasShape.
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Carve zero (allowed) blocks along the diagonal between random ids.
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            # One segment per row: leading ones are valid tokens, trailing zeros pad.
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1
                for seg_id in range(num_segments):
                    # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed
                    # TODO(rewang): Remove this constrain after cuDNN supports
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        # Randomly pad out the tail of each segment.
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            # Padded positions get segment id 0.
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            if self.qkv_layout == QKVLayout.T3HD:
                # Packed QKV shares the q-side segment layout.
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
                min_segment_len = None if self.window_size is None else self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        # Test different input formats
        if self.qkv_layout.is_thd():
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    pytest.skip("THD doesn't support mask input")
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets(
                        (self.seqlens_q, self.seqlens_kv),
                        (self.offsets_q, self.offsets_kv),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        (self.segment_pos_q, self.segment_pos_kv),
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")
        else:
            match self.seq_desc_format:
                case SeqDescFormat.Mask:
                    if self.attn_mask_type == AttnMaskType.NO_MASK:
                        self.sequence_desciptor = None
                    else:
                        self.sequence_desciptor = make_mask(
                            self.segment_ids_q,
                            self.segment_ids_kv,
                            self.segment_pos_q,
                            self.segment_pos_kv,
                            self.attn_mask_type,
                        )
                case SeqDescFormat.Seqlens:
                    self.sequence_desciptor = SequenceDescriptor.from_seqlens(
                        (
                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
                        ),
                    )
                case SeqDescFormat.SegmentIDs:
                    self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
                        (self.segment_ids_q, self.segment_ids_kv),
                        None,
                    )
                case _:
                    raise ValueError(f"Unknown {self.seq_desc_format=}")

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

        # Setup distributed sharding specs
        # Setup shardings for distributed tests
        self.qkvo_psec = PartitionSpec(
            self.mesh_resource.dp_resource,
            self.mesh_resource.cp_resource,
            self.mesh_resource.tp_resource,
            None,
        )
        self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

        mask_pspec = PartitionSpec(
            self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
        )
        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)

        match self.seq_desc_format:
            case SeqDescFormat.Mask:
                self.seq_desc_sharding = self.mask_sharding
            case _:

                def to_dp_shardings(x):
                    # Sequence descriptors are batch-only arrays: shard on dp.
                    pspec = PartitionSpec(self.mesh_resource.dp_resource)
                    return NamedSharding(self.mesh, pspec)

                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor)

        if self.bias_shape == BiasShape._1HSS:
            self.bias_pspec = PartitionSpec(
                None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._B1SS:
            self.bias_pspec = PartitionSpec(
                self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
            )
        elif self.bias_shape == BiasShape._11SS:
            self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        else:
            self.bias_pspec = PartitionSpec()
        self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

        self.dropout_rng_pspec = PartitionSpec(
            None,
        )
        self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec)

        self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None)
        self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec)

        # [batch][max_segments_per_batch]
        # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset.
        self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None)
        self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec)

        # Softmax aux sharding

        if self.cp_size > 1 and self.cp_load_balanced:
            self.cp_reorder_fn = partial(
                reorder_causal_load_balancing,
                cp_size=self.cp_size,
                tensor_format=self.qkv_layout.get_qkv_format(),
            )
            self.cp_inverse_reorder_fn = partial(
                inverse_reorder_causal_load_balancing,
                cp_size=self.cp_size,
                tensor_format=self.qkv_layout.get_qkv_format(),
            )
        else:
            # no-ops for non cp or non load balanced
            self.cp_reorder_fn = lambda x: x
            self.cp_inverse_reorder_fn = lambda x: x

zlsh80826's avatar
zlsh80826 committed
658
659
660
661
662
    def test_forward(self):
        """
        Run the fused-attention forward pass (JIT-compiled TE custom call) and
        compare it against the JAX reference implementation.
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]

        customcall_args = [
            # Put test data onto each GPU for distributed.
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # All kwargs are compile-time constants for the custom call.
        customcall_fused_dpa_jit = jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )

        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out = customcall_fused_dpa_jit(*customcall_args)
            # Undo the CP load-balancing reorder before comparing to the reference.
            primitive_out = self.cp_inverse_reorder_fn(primitive_out)

        reference_out = jax_dpa(*args, **kwargs)

        if self.is_training and self.dropout_prob > 0.0:
            # Dropout RNG streams differ between the two paths; only check that
            # the fused path runs to completion.
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded positions must be exactly zero; valid positions must match.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        if self.coll_count_ref is not None:
            # Inspect the compiled HLO for the expected collective communication.
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = (
                    customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
                )
            assert_equal_collectives(target_hlo, self.coll_count_ref)

zlsh80826's avatar
zlsh80826 committed
726
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward passes.

        If coll_count_ref is not None then the HLO of the backward function
        will be examined for the expected comms.
        """

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Reduce attention output to a scalar loss so value_and_grad can be applied.
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient (zero out padded query positions)
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            # Mean in FP32 to avoid FP16/BF16 accumulation overflow, then cast back.
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        # Reference (pure-JAX) arguments vs. customcall arguments placed on devices
        # with the context-parallel reordering applied to q/k/v.
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
            # THD params once we support those features on CP.
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
            out_shardings=(None, grad_shardings),
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        # Run the fused-attention custom call inside the mesh/fp8 context;
        # the pure-JAX reference runs without sharding.
        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        print_debug_tensor_stats(f"primitive_out", primitive_out)
        print_debug_tensor_stats(f"reference_grad_valid", reference_out)
        print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad, idx):
            # Split each gradient into valid (unpadded) and invalid (padded) parts
            # and check them separately: padded gradients must be exactly zero.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
            print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
            print_debug_tensor_stats(
                f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        # Undo the context-parallel load-balancing reorder before comparing
        # against the (unreordered) reference gradients.
        primitive_dq = self.cp_inverse_reorder_fn(primitive_dq)
        primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
        primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

        check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )

        # Optionally verify the compiled HLO contains exactly the expected
        # communication collectives for the configured parallelism.
        if self.coll_count_ref is not None:
            with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
                target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
            assert_equal_collectives(target_hlo, self.coll_count_ref)

886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901

@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
@pytest.mark.parametrize(
    "seq_desc_format",
    [
        pytest.param(SeqDescFormat.Mask, id="Mask"),
        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester.

    Class-level parametrization covers mask type, QKV layout,
    batch/sequence/head/dim shapes and dtype, dropout probability,
    sliding-window attention (SWA), and the sequence-descriptor format.
    Per-test parametrization (bias type/shape, training mode) is applied
    on the individual methods below.
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test forward with parameterized configs.

        The leading underscore keeps pytest from collecting this test:
        this test is not intended to run automatically during CI as it is
        time-consuming. It is kept for development and debugging.
        """
        # SWA window: attend to the previous s_kv // 10 tokens only
        # (right window 0); None disables sliding-window attention.
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
        seq_desc_format,
    ):
        """
        Test backward with parameterized configs.
        """
        # Same SWA window convention as _test_forward above.
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # is_training: backward always runs in training mode
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        runner.test_backward()