# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    fused_attn_qkvpacked,
    fused_attn_kvpacked,
    fused_attn,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    Workaround (WAR) for CUDA uninitialized error
    """
    # Calling custom calls before any JAX operation may cause a CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
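    # Grouped-query attention: each KV head serves h_q // h_kv query heads, so split
    # the query heads into that many groups before computing the attention logits.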
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
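        # A `True` entry in the mask marks a masked-out position: set its logit to the
        # dtype minimum so softmax assigns it (almost) zero weight.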
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
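        # Inverted dropout on the attention weights: keep entries with probability
        # keep_prob and rescale by 1/keep_prob to preserve the expected value.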
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
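    # A query position may attend to a KV position only if q_idx >= kv_idx (causal)
    # and both tokens are non-padding.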
    inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(inv_causal_mask, inv_padding_mask)


def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    """
    if is_causal_mask(attn_mask_type):
        inv_mask = make_decoder_mask(q_token, kv_token)
    else:
        inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
    mask = jnp.logical_not(inv_mask)
    return mask


def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    attn_mask_type = kwargs["attn_mask_type"]
    mask = make_mask(q_token, kv_token, attn_mask_type)

    output = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    TE customcall dot product attention implementation
    """
    attn_mask_type = kwargs["attn_mask_type"]
    mask = make_mask(q_token, kv_token, attn_mask_type)

    qkv_layout = kwargs.pop("qkv_layout")
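    # Pack Q/K/V according to the requested layout before dispatching to the
    # corresponding TE fused attention custom call.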
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
                query.dtype
            )
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
                query.dtype
            )


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    BIAS_1HSS = "1HSS"
    BIAS_B1SS = "B1SS"
    BIAS_BHSS = "BHSS"
    BIAS_11SS = "11SS"


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

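        # Query which fused attention backend (if any) supports this configuration.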
        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported input combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape.BIAS_1HSS
        ):
            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # The [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are used as a
                # workaround for an arbitrary mask, with True/False mapped to 0/-inf.
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
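                # Un-mask square diagonal blocks spanned by consecutive random ids;
                # everything else keeps the large negative (-inf-like) fill.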
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
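            # A token value of 1 marks a valid position; trailing zeros mark padding.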
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
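        # Standard scaled dot-product attention scaling of 1/sqrt(head_dim)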
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

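        # Skip the elementwise comparison when dropout is enabled, since the dropout
        # pattern differs between the fused kernel and the reference implementation.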
        if self.is_training and self.dropout_prob > 0.0:
            return

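        # Split the outputs along the sequence axis into valid and padded regions;
        # the padded part of the fused attention output must be zero.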
        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradients are small, so use a multiplier to amplify them
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only the valid part of the result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip the elementwise comparison when dropout is enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(
            primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
        )

        def check_dqkv(primitive, reference, valid_len):
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

            assert_allclose(
                primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
                jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
                reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
                reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d",
    [
        pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
        pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
        pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
        pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1024-12-12-64-CROSS"),
        pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
        pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    def test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
    ):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
        )
        runner.test_forward()

    @staticmethod
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
    ):
        """
        Test backward with parameterized configs
        """
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
        )
        runner.test_backward()