test_fused_attn.py 27.4 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
#
# See LICENSE for license information.
4
"""Tests for fused attention"""
5
from enum import Enum
zlsh80826's avatar
zlsh80826 committed
6
7
from dataclasses import dataclass
from functools import partial
8
from math import sqrt
9
10
11

import jax
import jax.numpy as jnp
12
import numpy as np
13
14
15
16
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
zlsh80826's avatar
zlsh80826 committed
17
18
from flax.linen.dtypes import promote_dtype
from jax import Array
19
from jax import value_and_grad, jit
zlsh80826's avatar
zlsh80826 committed
20
from jax.typing import ArrayLike, DTypeLike
21

22
23
24
25
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
26
    QKVFormat,
27
    fused_attn,
28
29
    fused_attn_thd,
    get_qkv_format,
30
)
31
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
32
33
34
35
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)
36
37

from utils import assert_allclose
38

39

40
@pytest.fixture(autouse=True, scope="module")
def init():
    """
    Workaround (WAR) for a CUDA "uninitialize" error.

    Runs once per module before any test: it touches JAX with a trivial array
    creation, then yields so the module's tests run afterwards.
    """
    # Issuing TE custom calls before JAX itself has been used may raise the CUDA error
    jnp.zeros(0)
    yield


50
51
52
53
54
55
56
57
58
59
60
61
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Reference attention similar to flax.linen.dot_product_attention, extended with GQA.

    The ``h_q`` query heads are split into ``h_kv`` groups of ``h_q // h_kv`` heads,
    and every group attends using a single shared key/value head.

    Args:
        query: array unpacked as ``(b, s_q, h_q, d)``.
        key: array unpacked as ``(b, s_kv, h_kv, d)``; ``h_q`` must be a multiple of ``h_kv``.
        value: key-aligned values, contracted against the attention probabilities.
        bias: optional post-scale bias added to the logits viewed as
            ``(b, h_q, s_q, s_kv)``; ``None`` disables it.
        mask: optional mask where ``True`` marks logits replaced by the dtype minimum.
            A mask whose rank differs from the 5-D grouped logits gets a group axis
            inserted at ``-3``.
        deterministic: when ``True`` dropout is skipped.
        scale_factor: multiplier applied to ``Q @ K^T``.
        dropout_rate: probability of dropping an attention probability.
        dropout_rng: PRNG key used for the dropout keep mask.
        dtype: dtype all inputs are promoted to before computing.

    Returns:
        The attention context, reshaped to ``query``'s shape.
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    compute_dtype = query.dtype

    batch, q_len, q_heads, head_dim = query.shape
    _, kv_len, kv_heads, _ = key.shape

    assert (q_heads % kv_heads == 0) and (q_heads >= kv_heads)
    group_size = q_heads // kv_heads

    # Query heads grouped per KV head: (b, s_q, h_kv, groups, d)
    q_grouped = jnp.reshape(query, (batch, q_len, kv_heads, group_size, head_dim))
    # Scores: (b, h_kv, groups, s_q, s_kv)
    scores = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", q_grouped, key)

    if bias is not None:
        # The bias is laid out per query head, so flatten the groups, add, then regroup
        per_head_shape = (batch, kv_heads * group_size, q_len, kv_len)
        grouped_shape = (batch, kv_heads, group_size, q_len, kv_len)
        scores = (scores.reshape(per_head_shape) + bias).reshape(grouped_shape)

    if mask is not None:
        if mask.ndim != scores.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        scores = jnp.where(mask, jnp.finfo(compute_dtype).min, scores)

    probs = jax.nn.softmax(scores).astype(compute_dtype)

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, probs.shape)
        # Inverted dropout: surviving probabilities are rescaled by 1 / keep_prob
        probs = probs * (keep.astype(compute_dtype) / jnp.asarray(keep_prob, dtype=compute_dtype))

    out = jnp.einsum("...hgqk,...khd->...qhgd", probs, value)
    return jnp.reshape(out, query.shape)
100
101


102
103
104
105
106
107
108
def is_causal_mask(mask: AttnMaskType):
    """
    Tell whether ``mask`` is a causal mask type.

    Returns ``True`` for ``AttnMaskType.CAUSAL_MASK`` and
    ``AttnMaskType.PADDING_CAUSAL_MASK``, ``False`` otherwise.
    """
    causal_types = (AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK)
    return mask in causal_types


109
def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Build an *inverse* causal mask.

    In the result ``True`` lets a query position attend to a key position and
    ``False`` masks that position out. Query position ``i`` may attend to key
    position ``j`` when ``i >= j``. Only the shapes of ``q_tokens`` and
    ``kv_tokens`` are used; their values are ignored.
    """

    def _positions(tokens):
        # int32 position index along the last axis, broadcast to the tokens' shape
        return jnp.broadcast_to(jnp.arange(tokens.shape[-1], dtype=jnp.int32), tokens.shape)

    return make_attention_mask(_positions(q_tokens), _positions(kv_tokens), jnp.greater_equal)
118

119

120
121
122
123
124
125
126
def make_mask(
    q_token: ArrayLike,
    kv_token: ArrayLike,
    segment_pad_q: ArrayLike,
    segment_pad_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
) -> Array:
    """
    Build the attention mask for ``attn_mask_type``.

    In the returned mask ``True`` masks the position out and ``False`` lets it take
    part in attention. The mask is built as follows:

    * A query/key pair is allowed only when both tokens carry the same non-zero id
      (id ``0`` is never allowed).
    * Causal mask types additionally allow a key position only when it does not come
      after the query position.
    * When both ``segment_pad_q`` and ``segment_pad_kv`` are given, pairs where either
      side is flagged with ``1`` are masked out as well.
    """

    def _same_nonzero_id(x, y):
        return jnp.logical_and(jnp.equal(x, y), x != 0)

    def _neither_padded(x, y):
        return jnp.logical_and(x != 1, y != 1)

    allowed = make_attention_mask(q_token, kv_token, _same_nonzero_id)

    if is_causal_mask(attn_mask_type):
        allowed = combine_masks(make_causal_mask(q_token, kv_token), allowed)

    if segment_pad_q is not None and segment_pad_kv is not None:
        pad_allowed = make_attention_mask(segment_pad_q, segment_pad_kv, _neither_padded)
        allowed = combine_masks(pad_allowed, allowed)

    return jnp.logical_not(allowed)
145

146

147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def get_seqlens_and_offsets(segment_ids, segment_pad):
    """
    Derive per-segment lengths and start offsets from packed (THD) segment ids.

    Args:
        segment_ids: 2-D integer array ``(batch, max_seqlen)``. Each token holds the
            id of the segment it belongs to; ``0`` marks padding.
        segment_pad: optional array of the same shape, or ``None``. Truthy entries
            flag tokens to be left out of the returned lengths.

    Returns:
        ``(seqlens, offsets)``:

        * ``seqlens``: shape ``(batch, max_seqlen)``. Entry ``k`` is the token count of
          segment id ``k + 1``, and a ``0`` is appended as the last column.
        * ``offsets``: shape ``(batch, max_seqlen + 1)``. Entry ``k`` is the index
          of the token where the ``k``-th segment starts; unused entries are ``-1``.

    NOTE(review): ``jnp.bincount`` is called with ``length=max_seqlen``, so a segment id
    equal to ``max_seqlen`` is not counted in ``seqlens``.
    """
    _, max_seqlen = segment_ids.shape
    count_ids = jax.vmap(partial(jnp.bincount, length=max_seqlen))

    def _find_offsets(x):
        # A new segment starts where the id changes to a non-zero value. Position 0
        # starts a segment only when it is not padding, keeping offsets aligned with
        # the per-id counts below.
        id_changed = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        is_segment_start = jnp.hstack((first_column, id_changed))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            is_segment_start
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    # Append the -1 sentinel at the end. Inserting at index -1 would place it before
    # the last column and displace a real offset when every slot is used.
    offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1)

    if segment_pad is not None:
        # Relabel padded tokens as id 0 so they are not counted in any segment
        ids_to_count = jnp.where(segment_pad, 0, segment_ids)
    else:
        ids_to_count = segment_ids
    counts = count_ids(jnp.asarray(ids_to_count).astype(jnp.int32))
    seqlens = counts[..., 1:]  # drop the count of id 0 (padding)
    seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1)
    return seqlens, offsets


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    Pure-JAX reference dot-product attention.

    Reads ``is_training``, ``scaling_factor`` and ``dropout_probability`` from
    ``kwargs`` and ignores every other key. The attention is computed in float32
    and the result is cast back to ``query.dtype``.
    """
    deterministic = not kwargs["is_training"]
    scale = kwargs["scaling_factor"]
    rate = kwargs["dropout_probability"]

    context = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=deterministic,
        scale_factor=scale,
        dropout_rate=rate,
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return context.astype(query.dtype)
199
200


201
202
203
204
205
206
207
208
209
210
211
212
213
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
214
    """
zlsh80826's avatar
zlsh80826 committed
215
    TE customcall dot product attention implementation
216
    """
217
218
    qkv_layout = kwargs["qkv_layout"]
    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
zlsh80826's avatar
zlsh80826 committed
219
    match qkv_layout:
220
        case QKVLayout.BS3HD | QKVLayout.T3HD:
zlsh80826's avatar
zlsh80826 committed
221
222
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
223
224
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
zlsh80826's avatar
zlsh80826 committed
225
226
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
227
228
229
230
231
232
233
234
235
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    if not is_thd:
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
    return fused_attn_thd(
236
237
238
239
240
241
242
243
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
244
    ).astype(query.dtype)
zlsh80826's avatar
zlsh80826 committed
245
246


247
class BiasShape(Enum):
    """
    Shapes of the attention bias tensor exercised by the fused attention tests.

    Each value spells the bias dimensions in (batch, heads, seqlen_q, seqlen_kv)
    order, with ``1`` standing for a size-1 (broadcast) dimension.
    """

    BIAS_1HSS = "1HSS"  # (1, num_heads_q, max_seqlen_q, max_seqlen_kv)
    BIAS_B1SS = "B1SS"  # (batch_size, 1, max_seqlen_q, max_seqlen_kv)
    BIAS_BHSS = "BHSS"  # (batch_size, num_heads_q, max_seqlen_q, max_seqlen_kv)
    BIAS_11SS = "11SS"  # (1, 1, max_seqlen_q, max_seqlen_kv)
256
257


zlsh80826's avatar
zlsh80826 committed
258
259
@dataclass
class FusedAttnRunner:
    """
    Fused attention runner.

    Holds one test configuration, builds its inputs and compares TE's fused
    attention custom call (``customcall_fused_dpa``) against the pure-JAX
    reference (``jax_dpa``).
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
    # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        # +1 for testing runtime_segments < max_segments
        return self.num_segments_per_seq + 1

    def _check_configs(self):
        """Skip the current test for configurations that cannot run; sets ``self.backend``."""
        is_thd = get_qkv_format(self.qkv_layout) == QKVFormat.THD

        # TODO(rewang): probably adds this in is_fused_attn_available
        if is_thd and self.attn_mask_type not in [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
        ]:
            pytest.skip("THD format requires padding masks.")

        if self.qkv_layout == QKVLayout.BS3HD or is_thd:
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape.BIAS_1HSS
        ):
            if self.attn_mask_type not in [
                AttnMaskType.NO_MASK,
                AttnMaskType.CAUSAL_MASK,
            ]:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        """Check the config, then build q/k/v, bias, tokens, masks and THD metadata."""
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                # Zero (unmask) the square block between each pair of consecutive random
                # ids; the slice is empty when the ids are not increasing.
                for start, end in zip(seq_id, seq_id[1:]):
                    self.bias = self.bias.at[:, :, start:end, start:end].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        def gen_valid(bs, max_seqlen, pad_ratio):
            # Tokens are 1 for the leading valid part and 0 for the trailing padding;
            # the second return value flags the padded positions.
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Not include paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for _ in range(num_segments):
                    segment_size = rng.integers(1, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    if with_segment_pad:
                        num_valid = rng.integers(1, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1
            return segment_ids, segment_pad

        is_thd = get_qkv_format(self.qkv_layout) == QKVFormat.THD

        if is_thd:
            self.num_segments_per_seq = 2
            self.token_q, self.segment_pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            # TODO(rewang): Check if qkvpacked supported different q/kv
            # TODO(rewang): Causal with different q/kv segment_id fails
            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
                self.token_kv = self.token_q
                self.segment_pad_kv = self.segment_pad_q
            else:
                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                )
            self.pad_q = self.segment_pad_q
            self.pad_kv = self.segment_pad_kv
        else:
            self.num_segments_per_seq = 1
            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
            self.segment_pad_q = self.segment_pad_kv = None

        self.mask = make_mask(
            self.token_q,
            self.token_kv,
            self.segment_pad_q,
            self.segment_pad_kv,
            self.attn_mask_type,
        )

        if is_thd:
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
                self.token_q, self.segment_pad_q
            )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
                self.token_kv, self.segment_pad_kv
            )
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
            self.mask_for_customcall = self.mask

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def _get_args(self):
        """
        Return ``(reference_args, customcall_args)``.

        These are the positional arguments for ``jax_dpa`` and
        ``customcall_fused_dpa``, built from the state set by ``_setup_inputs``.
        """
        reference_args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        return reference_args, customcall_args

    def _get_kwargs(self):
        """Keyword arguments shared by the custom-call and reference implementations."""
        return {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
        }

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args, customcall_args = self._get_args()
        kwargs = self._get_kwargs()

        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)

        # Skip elementwise comparison when dropout is active in training
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, _ = _split_valid_and_invalid(
            primitive_out, reference_out, self.pad_q
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """
        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")

        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args, customcall_args = self._get_args()
        kwargs = self._get_kwargs()

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batch has the same actual_seqlen, probably needs to extend the tests
            bias_mask = self.mask[0, 0]

            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            # The id mirrors the params: s_kv is 1024
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
    ):
        """
        Test forward with parameterized configs

        This test is not intended to run automatically during CI as it is time-consuming
        It is kept for development and debugging
        """
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
        )
        runner.test_forward()

    @staticmethod
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
    ):
        """
        Test backward with parameterized configs (always in training mode)
        """
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,  # is_training
            qkv_layout,
            bias_shape,
        )
        runner.test_backward()