# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
import sys
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend

from utils import assert_allclose


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep resources clean between tests
    """
    # Calling TE customcalls before any JAX operation may cause a CUDA
    # uninitialized error, so issue a trivial JAX op first
    _ = jnp.zeros(0)
    yield
    for arr in jax.live_arrays():
        arr.delete()


def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
                                  dtype: DTypeLike) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
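    # Grouped-query attention (GQA): each of the h_kv key/value heads serves
    # h_q // h_kv query heads, so the query heads are grouped accordingly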
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

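    # Weighted sum of the values per group, then merge the (kv-head, group) axes
    # back into the original h_q query heads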
    context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create padded causal mask
    """
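    # A query position may attend to a key position when q_idx >= kv_idx (causal)
    # and both tokens are non-padding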
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(causal_mask, padding_mask)


def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    output = general_dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           deterministic=not kwargs['is_training'],
                                           scale_factor=kwargs['scaling_factor'],
                                           dropout_rate=kwargs['dropout_probability'],
                                           dropout_rng=dropout_rng,
                                           dtype=jnp.float32)
    return output.astype(query.dtype)


def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    TE customcall dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax marks attendable positions with True, while the TE
    # customcall expects the opposite convention
    mask = jnp.logical_not(mask)

    qkv_layout = kwargs.pop('qkv_layout')
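    # Pack Q/K/V according to the requested layout before dispatching to the
    # corresponding TE fused attention call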
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            return cross_fused_attn(query, kv, bias, mask, dropout_rng,
                                    **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BSHD_BSHD:
            return fused_attn(query, key, value, bias, mask, dropout_rng,
                              **kwargs).astype(query.dtype)


class BiasShape(Enum):
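    """
    Bias layouts named by their dimensions (batch, heads, seqlen_q, seqlen_kv);
    a literal 1 denotes a broadcast dimension
    """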
    BIAS_1HSS = '1HSS'
    BIAS_B1SS = 'B1SS'
    BIAS_BHSS = 'BHSS'
    BIAS_11SS = '11SS'


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    def _check_configs(self):
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

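        # Query the fused attention backend available for this configuration and
        # skip the test when no backend supports it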
        self.backend = FusedAttnHelper(
            self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
            self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
            self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if self.bias_shape != BiasShape.BIAS_1HSS:
            if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
                pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
            elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
                            "the F16_arbitrary_seqlen backend.")

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.xfail("PyTest attempted to use an unrecognized bias layout!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
            else:
                # The [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes serve as a
                # workaround for an arbitrary mask, mapping True/False to 0/-Inf
                cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
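                # Zero out random diagonal blocks of the bias so those positions stay
                # attendable; everything else keeps the large negative fill and is
                # effectively masked out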
                for i in range(1, len(seq_id)):
                    self.bias = \
                        self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
        else:
            self.bias = None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
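            """Return the valid (unpadded) length and a 1/0 token mask of shape (bs, max_seqlen)"""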
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

        if self.is_training and self.dropout_prob > 0.:
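            # The fused kernel and the JAX reference are not expected to draw the same
            # dropout pattern, so skip the elementwise comparison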
            return

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Gradients are small, so apply a multiplier to amplify them
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only the valid (non-padded) part of the result for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
                                                       **kwargs), arg_nums))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
                                                       **kwargs), arg_nums))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out.astype(jnp.float32),
                        reference_out.astype(jnp.float32),
                        dtype=self.dtype)

        def check_dqkv(primitive, reference, valid_len):
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv = map(jnp.float32, primitive_dgrad[:3])
        reference_dq, reference_dk, reference_dv = map(jnp.float32, reference_dgrad[:3])

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = jnp.float32(primitive_dgrad[3])
            reference_dbias = jnp.float32(reference_dgrad[3])

            assert_allclose(
                primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
                dtype=self.dtype)

            # dbias padded part
            assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                            dtype=self.dtype)

            # dbias valid part
            assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                            dtype=self.dtype)

@pytest.mark.parametrize('bias_shape', [
    pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
    pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
    pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
    pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
])
@pytest.mark.parametrize('attn_bias_type', [
    pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
    pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16")
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
    pytest.param(32,  128,  128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
    pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
    pytest.param(32,  512,  128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
    pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
    pytest.param(32,  128,  128, 16,  8, 64, id='32-128-128-16-8-64-GQA'),
    pytest.param( 4, 2048, 2048, 12,  6, 64, id='4-2048-2048-12-6-64-GQA')
])
@pytest.mark.parametrize('dropout_prob', [
    pytest.param(0.0, id="DROP_0.0"),
    pytest.param(0.1, id="DROP_0.1")
])
@pytest.mark.parametrize('is_training', [
    pytest.param(True, id='TRAINING'),
    pytest.param(False, id='INFERENCE'),
])
class TestFusedAttn:
    """
    Fused attention tester
    """
    @staticmethod
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                     dropout_prob, dtype, is_training, qkv_layout, bias_shape):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout, bias_shape)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                      dropout_prob, dtype, is_training, qkv_layout, bias_shape):
        """
        Test backward with parameterized configs
        """
        if not is_training:
            pytest.skip("Backward pass does not support inference.")
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, True, qkv_layout, bias_shape)
        runner.test_backward()