# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""

from dataclasses import dataclass
from functools import partial
from math import sqrt

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays between tests to keep device resources clean
    """
    # Calling TE custom calls before any JAX operation may cause a CUDA
    # uninitialized error, so run a trivial JAX op first
    _ = jnp.zeros(0)
    yield
    for arr in jax.live_arrays():
        arr.delete()


def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
                                  bias: ArrayLike, mask: ArrayLike, deterministic: bool,
                                  dropout_rate: float, dropout_rng: ArrayLike,
                                  dtype: DTypeLike) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(dtype)

    b, s_q, h_q, d = query.shape
    _, _, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)

    if bias is not None:
        if bias.ndim != logits.ndim:
            bias = bias.reshape((1, *logits.shape[1:]))
        logits = logits + bias

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context
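
# A quick shape sanity check for the GQA grouping above (illustrative only;
# the dimensions here are arbitrary and are not part of the parameterized
# tests below):
#
#     q = jnp.zeros((2, 8, 12, 64))       # 12 query heads
#     k = v = jnp.zeros((2, 8, 6, 64))    # 6 KV heads -> num_groups = 2
#     out = general_dot_product_attention(q, k, v, bias=None, mask=None,
#                                         deterministic=True, dropout_rate=0.,
#                                         dropout_rng=None, dtype=jnp.float32)
#     assert out.shape == q.shape         # context reshaped back to (b, s_q, h_q, d)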


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create a combined causal and padding (decoder) mask
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
    padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
    return combine_masks(causal_mask, padding_mask)
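
# For example, q_tokens = kv_tokens = jnp.array([[1, 1, 0]]) yields a combined
# mask of shape (1, 1, 3, 3), where 1 marks positions that may be attended to:
#
#     [[[[1, 0, 0],
#        [1, 1, 0],
#        [0, 0, 0]]]]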


def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    output = general_dot_product_attention(query,
                                           key,
                                           value,
                                           bias=bias,
                                           mask=mask,
                                           deterministic=not kwargs['is_training'],
                                           dropout_rate=kwargs['dropout_probability'],
                                           dropout_rng=dropout_rng,
                                           dtype=jnp.float32)
    return output.astype(query.dtype)


def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    TE customcall dot product attention implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token, kv_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax uses True for positions to attend to, while the
    # TE custom calls expect True at positions to be masked out
    mask = jnp.logical_not(mask)

    qkv_layout = kwargs.pop('qkv_layout')
    match qkv_layout:
        case QKVLayout.BS3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
        case QKVLayout.BSHD_BS2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            return cross_fused_attn(query, kv, bias, mask, dropout_rng,
                                    **kwargs).astype(query.dtype)
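
# Layout summary for the two packed formats above (b = batch, s = sequence
# length, h = heads, d = head dim):
#   BS3HD:      q, k, v stacked into a single (b, s, 3, h, d) tensor,
#               dispatched to self_fused_attn
#   BSHD_BS2HD: separate q of shape (b, s_q, h_q, d), with k and v packed
#               into (b, s_kv, 2, h_kv, d), dispatched to cross_fused_attn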


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout

    def _check_configs(self):
        if self.qkv_layout == QKVLayout.BS3HD and self.num_heads_q != self.num_heads_kv:
            pytest.skip("BS3HD layout requires num_heads_q and num_heads_kv to be equal.")

        if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
            pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        if not is_fused_attn_kernel_available(
                self.dtype, self.dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
                self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.max_seqlen_q,
                self.max_seqlen_kv, self.head_dim):
            pytest.skip("Unsupported inputs combination or device compute capability.")

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)
        bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1)

        with_bias = self.attn_bias_type != AttnBiasType.NO_BIAS
        self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1) if with_bias else None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return valid_len, tokens
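
        # e.g. max_seqlen=128, pad_ratio=0.3 -> pad_len=38, valid_len=90,
        # so each row of tokens is [1]*90 + [0]*38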

        self.valid_len_q, self.token_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
        self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1. / sqrt(self.head_dim)
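        # e.g. head_dim=64 -> scaling_factor = 1 / sqrt(64) = 0.125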

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
        reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)

        primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
        reference_valid, _ = jnp.split(reference_out, (self.valid_len_q,), axis=1)

        # Skip the elementwise comparison when dropout is enabled, since the
        # two implementations draw different dropout patterns
        if self.is_training and self.dropout_prob > 0.:
            return

        np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-2, rtol=1e-4)
        np.testing.assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """
        if not self.is_training:
            pytest.skip("Backward doesn't support inference")

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # The gradient is small; use a multiplier to amplify it above the
            # dtype's precision floor
            gradient_multiplier = self.valid_len_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only the valid (non-padded) part of the output for the gradient
            ret_valid, _ = jnp.split(func(*args, **kwargs), (self.valid_len_q,), axis=1)
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
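
        # e.g. valid_len_q=90 with num_heads_q=16 gives a multiplier of 1440
        # (144 when the mask is causal)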

        args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': self.attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_prob,
            'is_training': self.is_training,
            'qkv_layout': self.qkv_layout,
        }

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
                                                       **kwargs), (0, 1, 2, 3)))
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                (0, 1, 2, 3)))

        primitive_out, primitive_dgrad = jitted_primitive(*args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip the elementwise comparison when dropout is enabled
        if self.dropout_prob > 0.:
            return

        np.testing.assert_allclose(primitive_out.astype(jnp.float32),
                                   reference_out.astype(jnp.float32),
                                   atol=1e-5,
                                   rtol=1e-3)

        # Convert the outputs to float32 for the elementwise comparison
        primitive_dq, primitive_dk, primitive_dv, primitive_dbias = map(
            jnp.float32, primitive_dgrad)
        reference_dq, reference_dk, reference_dv, reference_dbias = map(
            jnp.float32, reference_dgrad)

        def check_dqkv(primitive, reference, valid_len):
            primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
            reference_valid, reference_invalid = jnp.split(reference, (valid_len,), axis=1)

            np.testing.assert_allclose(primitive_valid, reference_valid, atol=1e-4, rtol=1e-3)
            assert jnp.allclose(primitive_invalid, reference_invalid)
            assert jnp.allclose(primitive_invalid, jnp.zeros_like(primitive_invalid))

        check_dqkv(primitive_dq, reference_dq, self.valid_len_q)
        check_dqkv(primitive_dk, reference_dk, self.valid_len_kv)
        check_dqkv(primitive_dv, reference_dv, self.valid_len_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            # Valid part of dbias
            np.testing.assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                                       reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
                                       atol=3e-5,
                                       rtol=1e-4)

            # Padded part of dbias (expected to be zero)
            np.testing.assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                                       reference_dbias[..., self.valid_len_q:, self.valid_len_kv:])

            assert jnp.allclose(
                primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
                jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]))


@pytest.mark.parametrize('attn_bias_type', [
    pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
    pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
])
@pytest.mark.parametrize('attn_mask_type', [
    pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
    pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
    pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
    pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
    pytest.param(QKVLayout.BS3HD, id='qkvpacked'),
    pytest.param(QKVLayout.BSHD_BS2HD, id='kvpacked'),
])
@pytest.mark.parametrize('dropout_prob', [0., 0.1])
@pytest.mark.parametrize('is_training',
                         [pytest.param(True, id='training'),
                          pytest.param(False, id='inference')])
@pytest.mark.parametrize(
    'dtype', [pytest.param(jnp.bfloat16, id="BF16"),
              pytest.param(jnp.float16, id="FP16")])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',
                         [(32, 128, 128, 16, 16, 64), (4, 2048, 2048, 12, 12, 64),
                          pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-cross'),
                          pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')])
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                     dtype, is_training, qkv_layout):
        """
        Test forward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
                      dtype, is_training, qkv_layout):
        """
        Test backward with parameterized configs
        """
        runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
                                 dropout_prob, dtype, is_training, qkv_layout)
        runner.test_backward()
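
# A typical invocation of this suite (the path is illustrative and depends on
# where the repository checkout places this file):
#     pytest -v test_fused_attn.py -k 'BF16 and training'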