# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    fused_attn_qkvpacked,
    fused_attn_kvpacked,
    fused_attn
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose


@pytest.fixture(autouse=True, scope='module')
def init():
    """
    WAR for CUDA uninitialized error
    """
    # Calling customcalls before jax may cause CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
                                  dtype: DTypeLike) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
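    # Grouped-query attention (GQA): each of the h_kv key/value heads serves
    # num_groups = h_q // h_kv query heads; MHA is the special case h_q == h_kv.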
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(inv_causal_mask, inv_padding_mask)

def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    """
    if is_causal_mask(attn_mask_type):
        inv_mask = make_decoder_mask(q_token, kv_token)
    else:
        inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
    mask = jnp.logical_not(inv_mask)
    return mask

def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    mask = make_mask(q_token, kv_token, attn_mask_type)

    output = general_dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           deterministic=not kwargs['is_training'],
                                           scale_factor=kwargs['scaling_factor'],
                                           dropout_rate=kwargs['dropout_probability'],
                                           dropout_rng=dropout_rng,
                                           dtype=jnp.float32)
    return output.astype(query.dtype)


def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    TE customcall dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    mask = make_mask(q_token, kv_token, attn_mask_type)

    qkv_layout = kwargs.pop('qkv_layout')
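    # Pack the separate query/key/value tensors to match the requested QKV layout
    # before dispatching to the corresponding TE custom call.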
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
                                       **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng,
                              **kwargs).astype(query.dtype)


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

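    # The names encode the bias layout: B = batch, H = heads, S = sequence,
    # 1 = a broadcast (size-1) dimension.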
    BIAS_1HSS = '1HSS'
    BIAS_B1SS = 'B1SS'
    BIAS_BHSS = 'BHSS'
    BIAS_11SS = '11SS'


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

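        # Query which fused attention backend (if any) supports this configuration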
        self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
                                       self.attn_bias_type.value, self.attn_mask_type.value,
                                       self.dropout_prob, self.num_heads_q, self.num_heads_kv,
                                       self.max_seqlen_q, self.max_seqlen_kv,
                                       self.head_dim).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS and \
                self.bias_shape != BiasShape.BIAS_1HSS:
            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "the F16_arbitrary_seqlen backend.")

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
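                # Zero out (i.e. unmask) random diagonal blocks so the bias behaves like
                # an arbitrary block mask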
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = \
                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
        else:
            self.bias = None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

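        # Token tensors mark valid positions with 1 and padded positions with 0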
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

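        # Skip elementwise comparison when dropout is enabled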
        if self.is_training and self.dropout_prob > 0.:
            return

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
                                                       **kwargs), arg_nums))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out.astype(jnp.float32),
                        reference_out.astype(jnp.float32),
                        dtype=self.dtype)

        def check_dqkv(primitive, reference, valid_len):
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

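            # dbias padded part must be all zeros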
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
                                                           self.valid_len_kv:]),
                            dtype=self.dtype)

            # dbias padded part
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            dtype=self.dtype)

            # dbias valid part
            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            dtype=self.dtype)


@pytest.mark.parametrize('attn_bias_type, bias_shape', [
    pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
    pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
    pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
    pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
    pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'),
    pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
    pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
    pytest.param(0.0, id="DROP_0.0"),
    pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize('is_training', [
        pytest.param(True, id='TRAINING'),
        pytest.param(False, id='INFERENCE'),
    ])
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                     dtype, is_training, qkv_layout, bias_shape):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                      dtype, qkv_layout, bias_shape):
        """
        Test backward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
        runner.test_backward()