test_fused_attn.py 19.1 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum
zlsh80826's avatar
zlsh80826 committed
6
7
from dataclasses import dataclass
from functools import partial
8
from math import sqrt
9
10
11
12
13
14
15

import jax
import jax.numpy as jnp
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
16
17
from flax.linen.dtypes import promote_dtype
from jax import Array
18
from jax import value_and_grad, jit
zlsh80826's avatar
zlsh80826 committed
19
from jax.typing import ArrayLike, DTypeLike
20

21
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
22
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
23
24
25
26
27
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

from transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose
28

29

30
31
32
33
34
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Delete every live JAX array after each test so device memory stays clean.
    """
    # Touch jnp before anything else: calling customcalls before JAX is
    # initialized may cause a CUDA uninitialize error.
    _ = jnp.zeros(0)
    yield
    for live_array in jax.live_arrays():
        live_array.delete()


def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
                                  dtype: DTypeLike) -> Array:
    """
    Reference dot-product attention, similar to flax.linen.dot_product_attention
    but with grouped-query attention (GQA) support.

    `query` is (b, s_q, h_q, d); `key`/`value` are (b, s_kv, h_kv, d) with
    h_q an integer multiple of h_kv.  Returns the attention context with the
    same shape as `query`.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    batch, seqlen_q, heads_q, dim = query.shape
    _, seqlen_kv, heads_kv, _ = key.shape
    assert (heads_q % heads_kv == 0) and (heads_q >= heads_kv)
    groups = heads_q // heads_kv
    # Fold the query heads into (kv_head, group) pairs so that each KV head
    # serves `groups` query heads.
    grouped_query = jnp.reshape(query, (batch, seqlen_q, heads_kv, groups, dim))
    # logits shape: (batch, heads_kv, groups, seqlen_q, seqlen_kv)
    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        # The bias is laid out per full query head: collapse the group axis,
        # apply the post-scale bias, then restore the grouped layout.
        logits = logits.reshape((batch, heads_kv * groups, seqlen_q, seqlen_kv))
        logits = logits + bias
        logits = logits.reshape((batch, heads_kv, groups, seqlen_q, seqlen_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        # True keeps the logit; False pushes it to the most negative finite
        # value so softmax assigns it (numerically) zero weight.
        logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        # Inverted dropout: rescale the kept probabilities by 1/keep_prob.
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
    return jnp.reshape(context, query.shape)


def is_causal_mask(mask: AttnMaskType):
    """
    Return True when the mask type applies causal (lower-triangular) masking.
    """
    causal_types = (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK)
    return mask in causal_types


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Build a causal attention mask combined with a padding mask.

    Token value 0 marks padding; any non-zero token is valid.
    """
    q_positions = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32),
                                   q_tokens.shape)
    kv_positions = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32),
                                    kv_tokens.shape)
    # Query position i may only attend to KV positions j with i >= j.
    causal_mask = make_attention_mask(q_positions, kv_positions, jnp.greater_equal)
    padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(causal_mask, padding_mask)


def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation

    Used as the reference against the TE customcall.  Computation is done in
    float32 and the result is cast back to the query dtype.
    """
    attn_mask_type = kwargs['attn_mask_type']
    mask = (make_decoder_mask(q_token, kv_token) if is_causal_mask(attn_mask_type) else
            make_attention_mask(q_token > 0, kv_token > 0))

    output = general_dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           deterministic=not kwargs['is_training'],
                                           scale_factor=kwargs['scaling_factor'],
                                           dropout_rate=kwargs['dropout_probability'],
                                           dropout_rng=dropout_rng,
                                           dtype=jnp.float32)
    return output.astype(query.dtype)


def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
128
    """
zlsh80826's avatar
zlsh80826 committed
129
    TE customcall dot product attention implementation
130
    """
131
    attn_mask_type = kwargs['attn_mask_type']
132
    if is_causal_mask(attn_mask_type):
zlsh80826's avatar
zlsh80826 committed
133
        mask = make_decoder_mask(q_token, kv_token)
134
135
136
137
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # mask invert
zlsh80826's avatar
zlsh80826 committed
138
139
140
141
142
143
144
    mask = jnp.logical_not(mask)

    qkv_layout = kwargs.pop('qkv_layout')
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
145
            return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
146
147
148
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
149
150
            return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
                                       **kwargs).astype(query.dtype)
151
152
153
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng,
                              **kwargs).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
154
155


156
class BiasShape(Enum):
    """
    Attention-bias layouts used by the fused attention tests, named after
    their four broadcast dimensions (batch, heads, seqlen_q, seqlen_kv),
    where '1' denotes a broadcast axis.
    """
    BIAS_1HSS = '1HSS'
    BIAS_B1SS = 'B1SS'
    BIAS_BHSS = 'BHSS'
    BIAS_11SS = '11SS'


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner

    Builds deterministic inputs for a single attention configuration and
    compares the TE fused-attention customcall against the JAX reference
    implementation, for both the forward pass and the backward (gradient)
    pass.
    """
    # Problem sizes
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    # Attention configuration
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        """Skip configurations no fused-attention backend supports."""
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
                                       self.attn_bias_type.value, self.attn_mask_type.value,
                                       self.dropout_prob, self.num_heads_q, self.num_heads_kv,
                                       self.max_seqlen_q, self.max_seqlen_kv,
                                       self.head_dim).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        # The restrictions below only concern the broadcast bias shapes
        # (B1SS, BHSS, 11SS) — exactly what the skip messages state — so the
        # [1, h, s, s] shape must not be skipped here.
        if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS and \
           self.bias_shape != BiasShape.BIAS_1HSS:
            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "the F16_arbitrary_seqlen backend.")

    def _setup_inputs(self):
        """Generate q/k/v, bias, and padded token inputs from a fixed seed."""
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128    # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = \
                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
        else:
            self.bias = None

        # Padding-style masks also exercise ragged sequences: pad 30% of tokens.
        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            # Valid tokens (1) come first; padding tokens (0) trail.
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

        # The two implementations draw different dropout masks, so elementwise
        # comparison is only meaningful without dropout.
        if self.is_training and self.dropout_prob > 0.:
            return

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        # The padded region must be zeroed out; the valid region must match.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """
        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
                                                       **kwargs), arg_nums))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.:
            return

        assert_allclose(primitive_out.astype(jnp.float32),
                        reference_out.astype(jnp.float32),
                        dtype=self.dtype)

        def check_dqkv(primitive, reference, valid_len):
            # Padded-region gradients must be zero and agree; valid-region
            # gradients must match the reference.
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

            # dbias of the fully padded region must be zero
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
                                                           self.valid_len_kv:]),
                            dtype=self.dtype)

            # dbias padded part
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            dtype=self.dtype)

            # dbias valid part
            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            dtype=self.dtype)


@pytest.mark.parametrize('attn_bias_type, bias_shape', [
    pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
    pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
    pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
    pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
    # Fix: the id previously read '1048' although s_kv is 1024.
    pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'),
    pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
    pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
    pytest.param(0.0, id="DROP_0.0"),
    pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize('is_training', [
        pytest.param(True, id='TRAINING'),
        pytest.param(False, id='INFERENCE'),
    ])
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                     dtype, is_training, qkv_layout, bias_shape):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                      dtype, qkv_layout, bias_shape):
        """
        Test backward with parameterized configs
        """
        # The backward pass is only meaningful in training mode, so
        # is_training is fixed to True.
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
        runner.test_backward()