# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    fused_attn,
    fused_attn_thd,
    make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)

from utils import assert_allclose


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialized error
    """
    # Calling customcalls before JAX may cause a CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape

    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

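    # A `True` entry in the mask marks a position that is masked out (see make_mask below)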
    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

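    # Apply dropout to the attention probabilities, rescaling the kept entries by 1 / keep_prob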
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange of the segment_ids will be applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask


@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start with 1, with 0s used for the paddings.
      Each segment is expected to start without paddings.
    - segment_pos marks the token position within each segment.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    # segment masks
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )

    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )

    # causal mask
    if attn_mask_type.is_causal():
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    # sliding window mask
    inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
    inv_mask = combine_masks(inv_mask, inv_swa_mask)
    mask = jnp.logical_not(inv_mask)
    return mask


@jax.jit
def get_seqlens_and_offsets(segment_ids):
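    """
    Derive per-segment sequence lengths and segment start offsets from segment_ids.
    Unused entries are padded with -1. The results are passed as the seqlens/offsets
    inputs of fused_attn_thd in these tests.
    """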
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

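    # _find_offsets marks the first token of each segment (an id change to a non-zero id,
    # or a non-zero first column) and returns those indices, padded with -1.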
    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
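    # Pack Q, K, V into the combined tensor layout expected by the requested qkv_layout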
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    if not qkv_layout.is_thd():
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
    return fused_attn_thd(
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
    ).astype(query.dtype)


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    _1HSS = "1HSS"  # (1, num_heads, seqlen_q, seqlen_kv)
    _B1SS = "B1SS"  # (batch, 1, seqlen_q, seqlen_kv)
    _BHSS = "BHSS"  # (batch, num_heads, seqlen_q, seqlen_kv)
    _11SS = "11SS"  # (1, 1, seqlen_q, seqlen_kv)


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Optional[Tuple[int, int]] = None

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
    # about generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        else:
            # +1 for testing runtime_segments < max_segments
            return self.num_segments_per_seq + 1

    def _check_configs(self):
        # TODO(rewang): probably add this to is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None
415

416
        if self.attn_mask_type.is_padding():
417
            pad_ratio = 0.3
418
419
        else:
            pad_ratio = 0.0

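        # gen_valid builds simple right-padded token masks: 1s for valid tokens followed by 0s,
        # plus the corresponding boolean padding mask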
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size,
            sequence_length,
            num_segments,
            seed,
            with_segment_pad=True,
            min_segment_len=None,
        ):
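            """
            Randomly split each sequence into up to `num_segments` segments and, optionally,
            add per-segment padding. Returns (segment_ids, segment_pos, segment_pad).
            """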
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Does not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for seg_id in range(num_segments):
                    # min_segment_len is used to force kv_len >= q_len because cuDNN kernels fail otherwise
                    # TODO(rewang): Remove this constraint after cuDNN supports it
                    min_segment_size = 1
                    if min_segment_len is not None:
                        min_segment_size = min_segment_len[i][seg_id]
                    segment_size = rng.integers(min_segment_size, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        num_valid = rng.integers(min_segment_size, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            if self.qkv_layout == QKVLayout.T3HD:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                # Force kv_len >= q_len for SWA; otherwise cuDNN kernels don't support it
                min_segment_len = None if self.window_size is None else self.seqlens_q
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                    min_segment_len=min_segment_len,
                )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.qkv_layout.is_thd():
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.mask_for_customcall = make_mask(
                self.segment_ids_q,
                self.segment_ids_kv,
                self.segment_pos_q,
                self.segment_pos_kv,
                self.attn_mask_type,
            )

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)

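        # Skip elementwise comparison when dropout enabled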
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # The gradient is small, so use a gradient multiplier to amplify it
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only the valid results for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches have the same actual_seqlen; the tests probably need to be extended
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming.
        It is kept for development and debugging.
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_forward()
    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_backward()