# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    fused_attn,
    fused_attn_thd,
    get_qkv_format,
    make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)

from utils import assert_allclose


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    WAR for CUDA uninitialized error
    """
    # Calling TE custom calls before any JAX operation may cause a CUDA
    # uninitialized error, so run a trivial JAX op first.
    _ = jnp.zeros(0)
    yield


def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape

    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
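    # Illustrative shape walk-through: query (2, 128, 16, 64) with key/value
    # (2, 128, 8, 64) gives num_groups = 16 // 8 = 2, so the logits above have
    # shape (2, 8, 2, 128, 128) and the final context reshapes to (2, 128, 16, 64).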

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
    """
    Create an inverse padded causal mask where `True` means allowing the
    corresponding position to participate in attention and `False` means
    masking out that position.
    """
    q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape)
    kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape)
    inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal)
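    # jnp.greater_equal keeps positions with q_idx >= kv_idx, i.e. a lower
    # triangular (diagonal included) pattern; e.g. for seqlen 3 the rows are
    # [1, 0, 0], [1, 1, 0], [1, 1, 1].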
    return inv_causal_mask


def make_mask(
    q_token: ArrayLike,
    kv_token: ArrayLike,
    segment_pad_q: ArrayLike,
    segment_pad_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.
    """
    inv_mask = make_attention_mask(
        q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )
    if is_causal_mask(attn_mask_type):
        inv_causal_mask = make_causal_mask(q_token, kv_token)
        inv_mask = combine_masks(inv_causal_mask, inv_mask)
    if segment_pad_q is not None and segment_pad_kv is not None:
        inv_pad_mask = make_attention_mask(
            segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
        )
        inv_mask = combine_masks(inv_pad_mask, inv_mask)
    if window_size is not None:
        max_seqlen_q = inv_mask.shape[-2]
        max_seqlen_kv = inv_mask.shape[-1]
        inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
        inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
        # In both inv_swa_mask and inv_mask, 0 means masked out
        inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)

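    # Illustrative example (ignoring the broadcast head axis): tokens [1, 1, 0]
    # attending to themselves with PADDING_CAUSAL_MASK give inv_mask rows
    # [1, 0, 0], [1, 1, 0], [0, 0, 0]; after the logical_not below, True marks
    # masked-out positions.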
    mask = jnp.logical_not(inv_mask)
    return mask


def get_seqlens_and_offsets(segment_ids, segment_pad):
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = jnp.ones((x.shape[0], 1), dtype=bool)
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
    if segment_pad is not None:
        segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids)
        padding_aware_seqlen = bincount_vmap(segment_id_with_paddings)
        output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1)
    else:
        output = jnp.insert(seqlens, -1, values=0, axis=-1)
    return output, offsets
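
# Example for get_seqlens_and_offsets (hedged, computed by hand): segment_ids
# [[1, 1, 2, 2, 2, 0]] with no segment_pad yields per-segment lengths beginning
# [2, 3, 0, ...] and segment start offsets beginning [0, 2, -1, ...], where -1
# fills the unused slots.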


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid
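
# Splitting into valid/invalid regions lets the tests assert the padded region
# is exactly zero while comparing only the valid region against the reference.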


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        mask=mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


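# Layout note (summarizing the match below): BS3HD / T3HD pack Q, K, V along a
# new axis at -3 into one tensor; BSHD_BS2HD / THD_T2HD pack only K and V; the
# BSHD_BSHD_BSHD / THD_THD_THD layouts keep the three tensors separate.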
def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    if not is_thd:
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
    return fused_attn_thd(
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
    ).astype(query.dtype)


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

    BIAS_1HSS = "1HSS"
    BIAS_B1SS = "B1SS"
    BIAS_BHSS = "BHSS"
    BIAS_11SS = "11SS"
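
    # The letters encode [batch, heads, s_q, s_kv]: "1HSS" broadcasts over batch,
    # "B1SS" over heads, "11SS" over both, and "BHSS" is fully per-batch/per-head
    # (see the bias_shape selection in FusedAttnRunner._setup_inputs).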


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner

    Holds one test configuration and implements input generation plus the
    forward and backward comparisons between the TE custom call and the JAX
    reference implementation.
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Optional[Tuple[int, int]] = None
    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known
    # issue with generating zero-length ragged tensors. This setting adjusts the tests to avoid the
    # zero-length cases.
    def _get_max_segments_per_sequence(self):
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        else:
            # +1 for testing runtime_segments < max_segments
            return self.num_segments_per_seq + 1

    def _check_configs(self):
        # TODO(rewang): probably add this check to is_fused_attn_available
        if get_qkv_format(self.qkv_layout) == QKVFormat.THD and self.attn_mask_type not in [
            AttnMaskType.PADDING_MASK,
            AttnMaskType.PADDING_CAUSAL_MASK,
        ]:
            pytest.skip("THD format requires padding masks.")

        qkv_format = get_qkv_format(self.qkv_layout)
        if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")

        if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")
        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")
        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape.BIAS_1HSS
        ):
            if self.attn_mask_type not in [
                AttnMaskType.NO_MASK,
                AttnMaskType.CAUSAL_MASK,
            ]:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )
    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )
        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape.BIAS_1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape.BIAS_BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape.BIAS_11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape.BIAS_1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None
        if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

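        # gen_valid builds a [bs, max_seqlen] map of 1s (valid) followed by 0s
        # (padding) and also returns its logical_not as the padding mask.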
        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Maximum segment size, not including paddings
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for _ in range(num_segments):
                    segment_size = rng.integers(1, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    if with_segment_pad:
                        num_valid = rng.integers(1, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1
            return segment_ids, segment_pad

        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.num_segments_per_seq = 2
            self.token_q, self.segment_pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            # TODO(rewang): Check if qkvpacked supports different q/kv
            # TODO(rewang): Causal with different q/kv segment_id fails
            if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
                self.token_kv = self.token_q
                self.segment_pad_kv = self.segment_pad_q
            else:
                self.token_kv, self.segment_pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                )
            self.pad_q = self.segment_pad_q
            self.pad_kv = self.segment_pad_kv
        else:
            self.num_segments_per_seq = 1
            self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio)
            self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
            self.segment_pad_q = self.segment_pad_kv = None

        self.mask = make_mask(
            self.token_q,
            self.token_kv,
            self.segment_pad_q,
            self.segment_pad_kv,
            self.attn_mask_type,
            self.window_size,
        )
        if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(
                self.token_q, self.segment_pad_q
            )
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(
                self.token_kv, self.segment_pad_kv
            )
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None
            self.mask_for_customcall = self.mask
        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }
        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)
        if self.is_training and self.dropout_prob > 0.0:
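            # Skip the elementwise comparison when dropout is enabled; the two
            # implementations presumably draw different dropout masks.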
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """
        self._setup_inputs()
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS:
            pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.")
        def grad_func(func, *args, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
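            # The mean over all of the roughly b * s * h * d output elements
            # shrinks each element's gradient, so scaling by max_seqlen_q *
            # num_heads_q keeps the gradients in a numerically comparable range.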
            if is_causal_mask(self.attn_mask_type):
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )
        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)
        # Skip elementwise comparison when dropout enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)
        def check_dqkv(primitive, reference, pad):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )
            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]
        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)
        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches have the same actual_seqlen; the tests probably need extending
            bias_mask = self.mask[0, 0]
            # Assert all masked dbias are 0s
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )
            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )
            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_bias_type, bias_shape",
    [
        pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
        pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
    ],
)
@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """
    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI as it is time-consuming.
        It is kept for development and debugging.
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_forward()
    @staticmethod
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_backward()