test_fused_attn.py 19.6 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum
zlsh80826's avatar
zlsh80826 committed
6
7
from dataclasses import dataclass
from functools import partial
8
from math import sqrt
9
10
11
12
13
14
15

import jax
import jax.numpy as jnp
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
16
17
from flax.linen.dtypes import promote_dtype
from jax import Array
18
from jax import value_and_grad, jit
zlsh80826's avatar
zlsh80826 committed
19
from jax.typing import ArrayLike, DTypeLike
20

21
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
22
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
23
24
25
26
27
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

from transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose
28

29

30
31
@pytest.fixture(autouse=True, scope='module')
def init():
    """
    WAR for CUDA uninitialize error
    """
    # Touch JAX first: issuing a customcall before any JAX op may trigger a
    # CUDA-uninitialized error, so force device initialization up front.
    _ = jnp.zeros(0)
    yield


zlsh80826's avatar
zlsh80826 committed
40
41
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
                                  dtype: DTypeLike) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    batch, seqlen_q, heads_q, dim = query.shape
    _, seqlen_kv, heads_kv, _ = key.shape

    # GQA: the query heads are split into groups, each group sharing one KV head.
    assert (heads_q % heads_kv == 0) and (heads_q >= heads_kv)
    groups = heads_q // heads_kv
    grouped_query = jnp.reshape(query, (batch, seqlen_q, heads_kv, groups, dim))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        # The bias is laid out for the full head count, so flatten the
        # (kv_head, group) axes, apply the post-scale bias, then restore them.
        logits = logits.reshape((batch, heads_kv * groups, seqlen_q, seqlen_kv))
        logits = logits + bias
        logits = logits.reshape((batch, heads_kv, groups, seqlen_q, seqlen_kv))

    if mask is not None:
        # `True` in the mask drops a position; broadcast it over the group axis.
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    attn_weights = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        # Inverted dropout: rescale the surviving weights by 1/keep_prob.
        attn_weights = attn_weights * (keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype))

    context = jnp.einsum('...hgqk,...khd->...qhgd', attn_weights, value)
    return jnp.reshape(context, query.shape)
82
83


84
85
86
87
88
89
90
def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    # Both the pure-causal and the padding+causal variants count as causal.
    return mask in (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK)


zlsh80826's avatar
zlsh80826 committed
91
def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    """
    # Positional indices broadcast to the token tensors' batch shape.
    q_positions = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32),
                                   q_tokens.shape)
    kv_positions = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32),
                                    kv_tokens.shape)
    # Causal: a query may attend to any key at or before its own position.
    inv_causal_mask = make_attention_mask(q_positions, kv_positions, jnp.greater_equal)
    # Padding: token value 0 marks a padded position on either side.
    inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(inv_causal_mask, inv_padding_mask)
101

102
103
104
105
106
107
108
109
110
111
112
113
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    """
    # Build the "inverse" mask (True == keep) first, then negate it.
    if is_causal_mask(attn_mask_type):
        inv_mask = make_decoder_mask(q_token, kv_token)
    else:
        inv_mask = make_attention_mask(q_token > 0, kv_token > 0)
    return jnp.logical_not(inv_mask)
114

zlsh80826's avatar
zlsh80826 committed
115
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    mask = make_mask(q_token, kv_token, kwargs['attn_mask_type'])

    # Run the reference attention in FP32 and cast back to the input dtype so
    # its accumulation error stays well below the fused kernel's tolerance.
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=not kwargs['is_training'],
        scale_factor=kwargs['scaling_factor'],
        dropout_rate=kwargs['dropout_probability'],
        dropout_rng=dropout_rng,
        dtype=jnp.float32)
    return output.astype(query.dtype)
133
134


zlsh80826's avatar
zlsh80826 committed
135
def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
136
    """
zlsh80826's avatar
zlsh80826 committed
137
    TE customcall dot product attention implementation
138
    """
139
    attn_mask_type = kwargs['attn_mask_type']
140
    mask = make_mask(q_token, kv_token, attn_mask_type)
zlsh80826's avatar
zlsh80826 committed
141
142
143
144
145
146

    qkv_layout = kwargs.pop('qkv_layout')
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
147
            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
148
149
150
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
151
152
            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
                                       **kwargs).astype(query.dtype)
153
154
155
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng,
                              **kwargs).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
156
157


158
class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """
    # The letters encode the broadcast layout of the bias tensor:
    # B = batch dim, H = heads dim, S = sequence dim, 1 = broadcast (size one).
    BIAS_1HSS = '1HSS'
    BIAS_B1SS = 'B1SS'
    BIAS_BHSS = 'BHSS'
    BIAS_11SS = '11SS'


zlsh80826's avatar
zlsh80826 committed
169
170
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner

    Builds randomized Q/K/V/bias inputs for one attention configuration and
    compares the TE fused-attention custom calls against the JAX reference
    implementation, both forward-only and forward+backward.
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        """Skip configurations the fused-attention kernels cannot run."""
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
                                       self.attn_bias_type.value, self.attn_mask_type.value,
                                       self.dropout_prob, self.num_heads_q, self.num_heads_kv,
                                       self.max_seqlen_q, self.max_seqlen_kv,
                                       self.head_dim).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # Only the non-1HSS bias shapes carry the restrictions the skip messages
        # describe; without the bias_shape guard the (supported) 1HSS bias with
        # padding masks would be skipped as well.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS and \
                self.bias_shape != BiasShape.BIAS_1HSS:
            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "the F16_arbitrary_seqlen backend.")

    def _setup_inputs(self):
        """Create Q/K/V/bias tensors, padding token masks and RNG state."""
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Zero out square blocks along the diagonal so some positions
                # remain attendable.
                for i in range(1, len(seq_id)):
                    self.bias = \
                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
        else:
            self.bias = None

        # Only the padding-style masks exercise padded sequences.
        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            # Token value 1 marks a valid position, 0 marks trailing padding.
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

        # With dropout active the two implementations draw different dropout
        # masks, so an elementwise comparison is meaningless.
        if self.is_training and self.dropout_prob > 0.:
            return

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        # The padded region must be zeroed; the valid region must match.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """
        self._setup_inputs()

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
                                                       **kwargs), arg_nums))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out.astype(jnp.float32),
                        reference_out.astype(jnp.float32),
                        dtype=self.dtype)

        def check_dqkv(primitive, reference, valid_len):
            # Padded-region gradients must be zero and agree; valid region must match.
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

            # dbias fully-padded corner must be exactly zero
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
                                                           self.valid_len_kv:]),
                            dtype=self.dtype)

            # dbias padded part
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            dtype=self.dtype)

            # dbias valid part
            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            dtype=self.dtype)

386
387
388
389
390
391
392

@pytest.mark.parametrize('attn_bias_type, bias_shape', [
    pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
    pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
    pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
    pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
    # id fixed: previously read '1048' while s_kv is 1024
    pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'),
    pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
    pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
    pytest.param(0.0, id="DROP_0.0"),
    pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize('is_training', [
        pytest.param(True, id='TRAINING'),
        pytest.param(False, id='INFERENCE'),
    ])
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                     dtype, is_training, qkv_layout, bias_shape):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                      dtype, qkv_layout, bias_shape):
        """
        Test backward with parameterized configs

        is_training is fixed to True since backward only applies to training.
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
        runner.test_backward()