# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
5
from enum import Enum
zlsh80826's avatar
zlsh80826 committed
6
7
from dataclasses import dataclass
from functools import partial
8
from math import sqrt
9
10
11

import jax
import jax.numpy as jnp
12
import numpy as np
13
14
15
16
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
17
18
from flax.linen.dtypes import promote_dtype
from jax import Array
19
from jax import value_and_grad, jit
zlsh80826's avatar
zlsh80826 committed
20
from jax.typing import ArrayLike, DTypeLike
21

22
23
24
25
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
26
    QKVFormat,
27
    fused_attn,
28
29
    fused_attn_thd,
    get_qkv_format,
30
)
31
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
32
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
33
34

from utils import assert_allclose
35

36

37
@pytest.fixture(scope="module", autouse=True)
def init():
    """
    Module-wide workaround (WAR) for the CUDA uninitialize error.

    Runs automatically once per module and executes a trivial JAX op before
    handing control to the tests.
    """
    # Issuing a custom call before JAX itself has run anything may cause a CUDA
    # uninitialize error, so create a throwaway empty array first.
    jnp.zeros(0)
    yield


47
48
49
50
51
52
53
54
55
56
57
58
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Reference dot-product attention in the spirit of
    flax.linen.dot_product_attention, extended with grouped-query (GQA) support.

    Args:
        query: rank-4 tensor unpacked as (b, s_q, h_q, d).
        key: rank-4 tensor unpacked as (b, s_kv, h_kv, d); h_q must be a
            multiple of h_kv.
        value: tensor contracted against the attention weights with the same
            "...khd" layout as `key`.
        bias: optional post-scale bias added to logits viewed as
            (b, h_q, s_q, s_kv); skipped when None.
        mask: optional mask where a truthy entry replaces the logit with the
            most negative finite value of the compute dtype; skipped when None.
        deterministic: when True, dropout is never applied.
        scale_factor: multiplier applied to the raw query-key products.
        dropout_rate: dropout probability; only used when it is > 0 and
            `deterministic` is False.
        dropout_rng: PRNG key consumed by the dropout sampling.
        dtype: dtype the inputs are promoted to before the computation.

    Returns:
        The attention context, reshaped to the shape of `query`.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    batch, seqlen_q, heads_q, head_dim = query.shape
    _, seqlen_kv, heads_kv, _ = key.shape

    # Every kv head serves a whole group of query heads.
    assert (heads_q % heads_kv == 0) and (heads_q >= heads_kv)
    groups = heads_q // heads_kv
    query_by_group = query.reshape((batch, seqlen_q, heads_kv, groups, head_dim))

    # logits: (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", query_by_group, key)

    if bias is not None:
        # Fold the group axis into the head axis so the post-scale bias lines up
        # with the logits, add it, then restore the grouped layout.
        ungrouped = logits.reshape((batch, heads_kv * groups, seqlen_q, seqlen_kv))
        logits = (ungrouped + bias).reshape((batch, heads_kv, groups, seqlen_q, seqlen_kv))

    if mask is not None:
        # A mask without the group axis gets one inserted so it broadcasts.
        aligned_mask = mask if mask.ndim == logits.ndim else jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(aligned_mask, jnp.finfo(dtype).min, logits)

    weights = jax.nn.softmax(logits).astype(dtype)

    use_dropout = not deterministic and dropout_rate > 0.0
    if use_dropout:
        keep_prob = 1.0 - dropout_rate
        kept = jax.random.bernoulli(dropout_rng, keep_prob, weights.shape)
        # Rescale the surviving weights by 1 / keep_prob.
        weights = weights * (kept.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype))

    grouped_context = jnp.einsum("...hgqk,...khd->...qhgd", weights, value)
    return jnp.reshape(grouped_context, query.shape)
97
98


99
100
101
102
103
104
105
def is_causal_mask(mask: AttnMaskType):
    """
    Tell whether `mask` is one of the causal mask types
    (CAUSAL_MASK or PADDING_CAUSAL_MASK)
    """
    causal_types = (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK)
    return mask in causal_types


106
def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Build the inverse causal mask for the given query/key token arrays. In the
    result `True` lets a position take part in attention and `False` masks it
    out. Only the shapes of the inputs are used: positions along the last axis
    are compared with `jnp.greater_equal` (query position vs. key position).
    """

    def _positions(tokens):
        # int32 position ids 0..len-1 along the last axis, broadcast to the
        # full token shape.
        return jnp.broadcast_to(jnp.arange(tokens.shape[-1], dtype=jnp.int32), tokens.shape)

    return make_attention_mask(_positions(q_tokens), _positions(kv_tokens), jnp.greater_equal)
115

116

117
118
119
120
121
122
123
def make_mask(
    q_token: ArrayLike,
    kv_token: ArrayLike,
    segment_pad_q: ArrayLike,
    segment_pad_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
) -> Array:
    """
    Build the attention mask for the requested mask type. In the returned mask
    `True` masks a position out and `False` lets it participate in attention.

    The mask is assembled in its inverse form (True = attend) and negated at
    the end:
      * a query/key pair may attend only when both carry the same non-zero id;
      * for causal mask types the inverse causal mask is combined in;
      * when both segment pad arrays are given, pairs where either side is
        flagged with 1 are excluded as well.
    """

    def _same_nonzero_id(q_id, kv_id):
        return jnp.logical_and(jnp.equal(q_id, kv_id), q_id != 0)

    def _neither_padded(q_pad, kv_pad):
        return jnp.logical_and(q_pad != 1, kv_pad != 1)

    allowed = make_attention_mask(q_token, kv_token, _same_nonzero_id)

    if is_causal_mask(attn_mask_type):
        allowed = combine_masks(make_causal_mask(q_token, kv_token), allowed)

    have_segment_pads = segment_pad_q is not None and segment_pad_kv is not None
    if have_segment_pads:
        unpadded = make_attention_mask(segment_pad_q, segment_pad_kv, _neither_padded)
        allowed = combine_masks(unpadded, allowed)

    return jnp.logical_not(allowed)
142

143

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def get_seqlens_and_offsets(segment_ids, segment_pad):
    """
    Derive per-segment lengths and segment start offsets from 2D segment ids.

    Args:
        segment_ids: (batch, max_seqlen) array of segment ids; id 0 is not
            counted as a segment.
        segment_pad: optional array selecting positions to exclude from the
            lengths (truthy = excluded), or None to count every position.

    Returns:
        A `(seqlens, offsets)` pair. `seqlens` holds the per-row counts of ids
        1..max_seqlen-1 with one extra 0 inserted before the last column.
        `offsets` holds the column index of every segment start, filled with -1
        where there are fewer starts than columns, with one extra -1 inserted
        before the last column.
    """
    num_rows, max_seqlen = segment_ids.shape
    count_ids = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    find_true = jax.vmap(partial(jnp.argwhere, size=max_seqlen, fill_value=-1))

    # A segment starts at column 0 and wherever a non-zero id differs from the
    # id immediately to its left.
    current, previous = segment_ids[..., 1:], segment_ids[..., :-1]
    changes_to_nonzero = jnp.logical_and(current != previous, current != 0)
    first_column = jnp.ones((num_rows, 1), dtype=bool)
    segment_starts = jnp.hstack((first_column, changes_to_nonzero))
    offsets = find_true(segment_starts).squeeze(-1)
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)

    if segment_pad is None:
        ids_to_count = segment_ids.astype(jnp.int32)
    else:
        # Re-label excluded positions as id 0 so they drop out of the counts.
        ids_to_count = jnp.where(segment_pad, 0, segment_ids)

    # Column 0 of the histogram counts id 0 and is discarded.
    seqlens = count_ids(ids_to_count)[..., 1:]
    return jnp.insert(seqlens, -1, values=0, axis=-1), offsets


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """
    Split `primitive` and `reference` into their valid and padded parts.
    JIT-compiled to speed up the verifications.

    `pad` gains two trailing axes so it broadcasts over the last two dimensions
    of the tensors. The "valid" tensors are zeroed where `pad` is set and the
    "invalid" tensors are zeroed where it is not.

    Returns:
        (primitive_valid, primitive_invalid, reference_valid, reference_invalid)
    """
    pad_mask = pad[..., jnp.newaxis, jnp.newaxis]

    def _partition(tensor):
        return jnp.where(pad_mask, 0, tensor), jnp.where(pad_mask, tensor, 0)

    primitive_valid, primitive_invalid = _partition(primitive)
    reference_valid, reference_invalid = _partition(reference)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    Reference dot product attention built from native JAX ops.

    Reads `is_training`, `scaling_factor` and `dropout_probability` from
    `kwargs` (other entries are ignored), runs the computation in float32 and
    casts the result back to the dtype of `query`.
    """
    is_training = kwargs["is_training"]
    scaling_factor = kwargs["scaling_factor"]
    dropout_probability = kwargs["dropout_probability"]

    reference = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=not is_training,
        scale_factor=scaling_factor,
        dropout_rate=dropout_probability,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return reference.astype(query.dtype)
196
197


198
199
200
201
202
203
204
205
206
207
208
209
210
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    Dot product attention through the TE custom call.

    Packs query/key/value according to `kwargs["qkv_layout"]`, then dispatches
    to `fused_attn_thd` for THD-format layouts (using the seqlens/offsets) or to
    `fused_attn` otherwise (using `mask`, after dropping the
    `max_segments_per_seq` entry from the kwargs). The result is cast to the
    dtype of `query`.

    Raises:
        ValueError: if the layout is not one of the handled QKVLayout values.
    """
    qkv_layout = kwargs["qkv_layout"]
    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD

    # Packed layouts stack the tensors along a new axis at position -3.
    add_pack_axis = partial(jnp.expand_dims, axis=-3)

    if qkv_layout in (QKVLayout.BS3HD, QKVLayout.T3HD):
        packed_qkv = jnp.concatenate(
            [add_pack_axis(tensor) for tensor in (query, key, value)], axis=-3
        )
        qkv_args = (packed_qkv,)
    elif qkv_layout in (QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD):
        packed_kv = jnp.concatenate([add_pack_axis(key), add_pack_axis(value)], axis=-3)
        qkv_args = (query, packed_kv)
    elif qkv_layout in (QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD):
        qkv_args = (query, key, value)
    else:
        raise ValueError(f"Unsupported {qkv_layout=}")

    if is_thd:
        thd_out = fused_attn_thd(
            qkv_args, bias, seqlens_q, seqlens_kv, offsets_q, offsets_kv, dropout_rng, **kwargs
        )
        return thd_out.astype(query.dtype)

    kwargs.pop("max_segments_per_seq")
    return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
235
236


237
class BiasShape(Enum):
    """
    Bias tensor shapes exercised by the fused attention tests.

    Each value spells out the four bias dimensions: B = batch size, H = number
    of query heads, S = sequence length, 1 = size-one axis.
    """

    # (1, num_heads_q, max_seqlen_q, max_seqlen_kv)
    BIAS_1HSS = "1HSS"
    # (batch_size, 1, max_seqlen_q, max_seqlen_kv)
    BIAS_B1SS = "B1SS"
    # (batch_size, num_heads_q, max_seqlen_q, max_seqlen_kv)
    BIAS_BHSS = "BHSS"
    # (1, 1, max_seqlen_q, max_seqlen_kv)
    BIAS_11SS = "11SS"
246
247


zlsh80826's avatar
zlsh80826 committed
248
249
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner

    Holds one test configuration, generates the matching random inputs and
    compares the TE custom call (`customcall_fused_dpa`) against the native JAX
    reference (`jax_dpa`) for the forward and the backward pass.
    """

    # Tensor dimensions used to build the q/k/v and bias shapes.
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    # Attention configuration forwarded to the kernels.
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    # Shape of the bias tensor; not consulted when attn_bias_type is NO_BIAS.
    bias_shape: BiasShape

    def _check_configs(self):
        """
        Skip the current test when the configuration is not supported, and
        record the selected fused attention backend in `self.backend`.
        """
        # TODO(rewang): probably adds this in is_fused_attn_available
        if get_qkv_format(self.qkv_layout) == QKVFormat.THD and self.attn_mask_type not in [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
        ]:
            pytest.skip("THD format requires padding masks.")

        if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape.BIAS_1HSS
        ):
            if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        """
        Validate the configuration and generate every input used by the tests:
        q/k/v, the optional bias, token/segment ids and padding flags, the
        reference mask, the THD seqlens/offsets, the dropout RNG key and the
        softmax scaling factor. Everything is stored on `self`.
        """
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        # Uniform samples with minval=-1.0 (maxval keeps its default).
        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Open a square [start:end, start:end] block for every pair of
                # consecutive random ids.
                for start, end in zip(seq_id, seq_id[1:]):
                    self.bias = self.bias.at[:, :, start:end, start:end].set(0.0)
        else:
            self.bias = None

        # Only the padding mask types get padded positions.
        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            # tokens: ones for the valid prefix followed by zeros for the padded tail;
            # the second return value is its negation (True = padded).
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for _ in range(num_segments):
                    segment_size = rng.integers(1, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    if with_segment_pad:
                        # Keep a random valid prefix and flag the rest of the segment as pad.
                        num_valid = rng.integers(1, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                # Everything after the last segment is padding.
                segment_pad[i, current_pos:sequence_length] = 1
            return segment_ids, segment_pad

        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.num_segments_per_seq = 2
            self.token_q, self.segment_pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            # TODO(rewang): Check if qkvpacked supported different q/kv
            # TODO(rewang): Causal with different q/kv segment_id fails
            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
                self.token_kv = self.token_q
                self.segment_pad_kv = self.segment_pad_q
            else:
                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
                    self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024
                )
            self.pad_q = self.segment_pad_q
            self.pad_kv = self.segment_pad_kv
        else:
            self.num_segments_per_seq = 1
            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
            self.segment_pad_q = self.segment_pad_kv = None

        self.mask = make_mask(
            self.token_q,
            self.token_kv,
            self.segment_pad_q,
            self.segment_pad_kv,
            self.attn_mask_type,
        )

        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
                self.token_q, self.segment_pad_q
            )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
                self.token_kv, self.segment_pad_kv
            )
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
            self.mask_for_customcall = self.mask

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def _make_call_inputs(self):
        """
        Build the inputs shared by the forward and backward tests.

        Returns:
            (args, customcall_args, kwargs): positional arguments for `jax_dpa`,
            positional arguments for `customcall_fused_dpa`, and the keyword
            arguments passed to both. Requires `_setup_inputs` to have run.
        """
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            # +1 for testing runtime_segments < max_segments
            "max_segments_per_seq": self.num_segments_per_seq + 1,
        }
        return args, customcall_args, kwargs

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()
        args, customcall_args, kwargs = self._make_call_inputs()

        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)

        # Skip elementwise comparison when dropout is active
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        # Padded query positions must be zero; valid positions must match the reference.
        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args, customcall_args, kwargs = self._make_call_inputs()

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad):
            # Gradients at padded positions must be zero and agree with the
            # reference; gradients at valid positions must match the reference.
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
        # The id spells out the actual s_kv of 1024.
        pytest.param(2, 2048, 1024, 12, 12, 64, jnp.bfloat16, id="2-2048-1024-12-12-64-BF16-CROSS"),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester

    The class-level parametrizations (bias type/shape, mask type, QKV layout,
    problem size/dtype and dropout probability) apply to every test method.
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        # Keyword arguments keep the 13 configuration values from being
        # silently mis-ordered.
        runner = FusedAttnRunner(
            batch_size=b,
            max_seqlen_q=s_q,
            max_seqlen_kv=s_kv,
            num_heads_q=h_q,
            num_heads_kv=h_kv,
            head_dim=d,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            dropout_prob=dropout_prob,
            dtype=dtype,
            is_training=is_training,
            qkv_layout=qkv_layout,
            bias_shape=bias_shape,
        )
        runner.test_forward()

    @staticmethod
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
    ):
        """
        Test backward with parameterized configs
        """
        # The backward test always runs in training mode.
        runner = FusedAttnRunner(
            batch_size=b,
            max_seqlen_q=s_q,
            max_seqlen_kv=s_kv,
            num_heads_q=h_q,
            num_heads_kv=h_kv,
            head_dim=d,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            dropout_prob=dropout_prob,
            dtype=dtype,
            is_training=True,
            qkv_layout=qkv_layout,
            bias_shape=bias_shape,
        )
        runner.test_backward()