# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    fused_attn,
    fused_attn_thd,
    get_qkv_format,
    make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)

from utils import assert_allclose

@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialized error
    """
    # Calling custom calls before JAX is initialized may cause a CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype
    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
    softmax_out = jax.nn.softmax(logits).astype(dtype)
    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    return inv_causal_mask

def make_mask(
    q_token: ArrayLike,
    kv_token: ArrayLike,
    segment_pad_q: ArrayLike,
    segment_pad_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    """
    inv_mask = make_attention_mask(
        q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )
    if is_causal_mask(attn_mask_type):
        inv_causal_mask = make_causal_mask(q_token, kv_token)
        inv_mask = combine_masks(inv_causal_mask, inv_mask)
    if segment_pad_q is not None and segment_pad_kv is not None:
        inv_pad_mask = make_attention_mask(
            segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
        )
        inv_mask = combine_masks(inv_pad_mask, inv_mask)

    if window_size is not None:
        max_seqlen_q = inv_mask.shape[-2]
        max_seqlen_kv = inv_mask.shape[-1]
        inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
        inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
        # In inv_swa_mask and inv_mask, 0 means masked out
        inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)

    mask = jnp.logical_not(inv_mask)
    return mask

def get_seqlens_and_offsets(segment_ids, segment_pad):
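    """
    Convert per-token segment ids (and optional per-token padding flags) into the
    per-segment sequence lengths and segment start offsets used by the THD (ragged)
    fused attention path.
    """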
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
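        # A position starts a new segment when its id is non-zero and differs from the
        # previous position; the first column is always treated as a segment start.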
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = jnp.ones((x.shape[0], 1), dtype=bool)
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
    if segment_pad is not None:
        segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
        padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
        output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
    else:
        output = jnp.insert(seqlens, -1, values=0, axis=-1)
    return output, offsets


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
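    # Pack Q, K, V according to the requested QKV layout before calling the TE custom call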
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    if not is_thd:
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
    return fused_attn_thd(
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
    ).astype(query.dtype)


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    BIAS_1HSS = "1HSS"  # [1, num_heads, seqlen_q, seqlen_kv]
    BIAS_B1SS = "B1SS"  # [batch, 1, seqlen_q, seqlen_kv]
    BIAS_BHSS = "BHSS"  # [batch, num_heads, seqlen_q, seqlen_kv]
    BIAS_11SS = "11SS"  # [1, 1, seqlen_q, seqlen_kv]


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Optional[Tuple[int, int]] = None

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
    # with generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        else:
            # +1 for testing runtime_segments < max_segments
            return self.num_segments_per_seq + 1

    def _check_configs(self):
        # TODO(rewang): probably adds this in is_fused_attn_available
        if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
        ]:
            pytest.skip("THD format requires padding masks.")

        qkv_format = get_qkv_format(self.qkv_layout)
        if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")

        if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape.BIAS_1HSS
        ):
            if self.attn_mask_type not in [
                AttnMaskType.NO_MASK,
                AttnMaskType.CAUSAL_MASK,
            ]:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None
        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
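            # Mark the leading tokens as valid (1) and the trailing pad_ratio fraction as
            # padding (0); also return the inverted mask as the pad mask.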
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
        ):
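            """
            Randomly split each sequence into up to `num_segments` segments: segment ids
            start at 1 and 0 marks the padded tail. When `with_segment_pad` is set, a
            random tail of each segment is additionally marked as padding in `segment_pad`.
            """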
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Does not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for _ in range(num_segments):
                    segment_size = rng.integers(1, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    if with_segment_pad:
                        num_valid = rng.integers(1, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1
            return segment_ids, segment_pad

        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.num_segments_per_seq = 2
            self.token_q, self.segment_pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            # TODO(rewang): Check if qkvpacked supported different q/kv
            # TODO(rewang): Causal with different q/kv segment_id fails
            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
                self.token_kv = self.token_q
                self.segment_pad_kv = self.segment_pad_q
            else:
                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                )
            self.pad_q = self.segment_pad_q
            self.pad_kv = self.segment_pad_kv
        else:
            self.num_segments_per_seq = 1
            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
            self.segment_pad_q = self.segment_pad_kv = None

        self.mask = make_mask(
            self.token_q,
            self.token_kv,
            self.segment_pad_q,
            self.segment_pad_kv,
            self.attn_mask_type,
            self.window_size,
        )
        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
                self.token_q, self.segment_pad_q
            )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
                self.token_kv, self.segment_pad_kv
            )
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
            self.mask_for_customcall = self.mask
        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }
        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradients are small; use a gradient multiplier to amplify them
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only the valid results for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )
        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)
        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)
        def check_dqkv(primitive, reference, pad):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )
            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches have the same actual_seqlen; the tests probably need to be extended
            bias_mask = self.mask[0, 0]
            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )
            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )
            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_forward()

    @staticmethod
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_backward()