# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""
from enum import Enum
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional
import random

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import make_attention_mask
from flax.linen.dtypes import promote_dtype
from jax import Array
from jax import value_and_grad, jit
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    QKVFormat,
    fused_attn,
    fused_attn_thd,
    make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
    NVTE_Fused_Attn_Backend,
    get_cudnn_version,
)

from utils import assert_allclose


@pytest.fixture(autouse=True, scope="module")
def init():
    """
    Workaround (WAR) for the CUDA uninitialized error
    """
    # Calling customcalls before JAX is initialized may cause a CUDA uninitialized error
    _ = jnp.zeros(0)
    yield


@partial(jax.jit, static_argnums=(5, 6, 7, 9))
def general_dot_product_attention(
    query: ArrayLike,
    key: ArrayLike,
    value: ArrayLike,
    bias: ArrayLike,
    mask: ArrayLike,
    deterministic: bool,
    scale_factor: float,
    dropout_rate: float,
    dropout_rng: ArrayLike,
    dtype: DTypeLike,
) -> Array:
    """
    Similar to flax.linen.dot_product_attention but with GQA support
    """
    query, key, value, bias = promote_dtype(query, key, value, bias, dtype=dtype)
    dtype = query.dtype

    b, s_q, h_q, d = query.shape
    _, s_kv, h_kv, _ = key.shape
    assert (h_q % h_kv == 0) and (h_q >= h_kv)
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits with shape (b, h_kv, num_groups, s_q, s_kv)
    logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)

    if bias is not None:
        # reshape logits without groups
        logits = logits.reshape((b, h_kv * num_groups, s_q, s_kv))
        # apply post-scale bias
        logits = logits + bias
        # reshape logits back to original
        logits = logits.reshape((b, h_kv, num_groups, s_q, s_kv))

    if mask is not None:
        if mask.ndim != logits.ndim:
            mask = jnp.expand_dims(mask, axis=-3)
        logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

    softmax_out = jax.nn.softmax(logits).astype(dtype)

    if not deterministic and dropout_rate > 0.0:
        keep_prob = 1.0 - dropout_rate
        keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
        multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
        softmax_out = softmax_out * multiplier

    context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
    context = jnp.reshape(context, query.shape)
    return context
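

# Added note (illustrative sketch, hypothetical sizes): with GQA, the h_q query heads share
# the h_kv key/value heads in groups of num_groups = h_q // h_kv. For example, with h_q = 16
# and h_kv = 8, a query of shape (b, s_q, 16, d) is regrouped to (b, s_q, 8, 2, d), and the
# einsum "...qhgd,...khd->...hgqk" produces logits of shape (b, 8, 2, s_q, s_kv), which then
# go through the bias, mask, softmax, and dropout steps above.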


@jax.jit
def make_causal_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike = None,
    segment_pos_kv: ArrayLike = None,
) -> Array:
    """
    Create inverse padded causal mask where `True` means allowing the corresponding
    position to participate in attention and `False` means masking out that position.
    If segment_pos is not provided, an arange over the last axis of segment_ids is applied.
    """
    if segment_pos_q is None:
        segment_pos_q = jnp.broadcast_to(
            jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
        )
    if segment_pos_kv is None:
        segment_pos_kv = jnp.broadcast_to(
            jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
        )
    inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal)
    return inv_causal_mask
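

# Added note (illustrative sketch, hypothetical input): for segment_pos_q = segment_pos_kv
# = [[0, 1, 2]], make_attention_mask with jnp.greater_equal yields a lower-triangular
# inverse causal mask of shape (1, 1, 3, 3):
#   [[1, 0, 0],
#    [1, 1, 0],
#    [1, 1, 1]]
# where a non-zero entry allows the query position to attend to that key position.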


@partial(jax.jit, static_argnums=(4, 5))
def make_mask(
    segment_ids_q: ArrayLike,
    segment_ids_kv: ArrayLike,
    segment_pos_q: ArrayLike,
    segment_pos_kv: ArrayLike,
    attn_mask_type: AttnMaskType,
    window_size: Optional[Tuple[int, int]] = None,
) -> Array:
    """
    Create attention mask based on mask type. A `True` value in the mask means
    masking out the corresponding position and a `False` value means allowing
    that position to participate in attention.

    - segment_ids should start from 1, with 0 used for padding.
      Each segment is expected to start without padding.
    - segment_pos marks the token position within its segment.

    An example pair of segment_ids and segment_pos:
    segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
    segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    """
    inv_mask = make_attention_mask(
        segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
    )
    if attn_mask_type.is_causal():
        if segment_pos_q is None:
            segment_pos_q = jnp.broadcast_to(
                jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
            )
        if segment_pos_kv is None:
            segment_pos_kv = jnp.broadcast_to(
                jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
            )
        inv_causal_mask = make_attention_mask(
            segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
        )
        inv_mask = combine_masks(inv_causal_mask, inv_mask)

    if window_size is not None:
        max_seqlen_q = inv_mask.shape[-2]
        max_seqlen_kv = inv_mask.shape[-1]
        inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
        inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
        # In inv_swa_mask and inv_mask 0 is masked out
        inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)

    mask = jnp.logical_not(inv_mask)
    return mask
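

# Added note (illustrative sketch, hypothetical input): for
# segment_ids_q = segment_ids_kv = [[1, 1, 0]] with AttnMaskType.PADDING_CAUSAL_MASK and
# no sliding window, the returned mask of shape (1, 1, 3, 3) is
#   [[False,  True,  True],
#    [False, False,  True],
#    [ True,  True,  True]]
# i.e. True marks masked-out positions (future tokens and the padded row/column).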


@jax.jit
def get_seqlens_and_offsets(segment_ids):
    batch, max_seqlen = segment_ids.shape
    bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen))
    seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32))
    seqlens = seqlens_with_zero[..., 1:]

    def _find_offsets(x):
        same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0)
        first_column = x[..., :1] != 0
        same_as_previous = jnp.hstack((first_column, same_as_previous))
        return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))(
            same_as_previous
        ).squeeze(-1)

    offsets = _find_offsets(segment_ids)
    offsets = jnp.insert(offsets, -1, values=-1, axis=-1)
    seqlens = jnp.insert(seqlens, -1, values=0, axis=-1)
    seqlens = jnp.where(seqlens, seqlens, -1)
    return seqlens, offsets
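

# Added note (illustrative sketch, hypothetical input): for segment_ids = [[1, 1, 2, 2, 2, 0]]
# the two segments have lengths 2 and 3 and start at token offsets 0 and 2, so the expected
# padded outputs are roughly
#   seqlens: [[2, 3, -1, -1, -1, -1]]
#   offsets: [[0, 2, -1, -1, -1, -1, -1]]
# with -1 as the sentinel for unused slots.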


@jax.jit
def _split_valid_and_invalid(primitive, reference, pad):
    """Use JIT to speed up the verifications"""
    primitive_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, primitive)
    primitive_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], primitive, 0)
    reference_valid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], 0, reference)
    reference_invalid = jnp.where(pad[..., jnp.newaxis, jnp.newaxis], reference, 0)
    return primitive_valid, primitive_invalid, reference_valid, reference_invalid


def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
    """
    JAX native dot product attention implementation
    """
    output = general_dot_product_attention(
        query,
        key,
        value,
        bias,
        mask,
        deterministic=not kwargs["is_training"],
        scale_factor=kwargs["scaling_factor"],
        dropout_rate=kwargs["dropout_probability"],
        dropout_rng=dropout_rng,
        dtype=jnp.float32,
    )
    return output.astype(query.dtype)


def customcall_fused_dpa(
    query,
    key,
    value,
    bias,
    mask,
    seqlens_q,
    seqlens_kv,
    offsets_q,
    offsets_kv,
    dropout_rng,
    **kwargs,
):
    """
    TE customcall dot product attention implementation
    """
    qkv_layout = kwargs["qkv_layout"]
    match qkv_layout:
        case QKVLayout.BS3HD | QKVLayout.T3HD:
            query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
            qkv = jnp.concatenate((query, key, value), axis=-3)
            qkv_args = (qkv,)
        case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD:
            key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
            kv = jnp.concatenate((key, value), axis=-3)
            qkv_args = (query, kv)
        case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD:
            qkv_args = (query, key, value)
        case _:
            raise ValueError(f"Unsupported {qkv_layout=}")
    if not qkv_layout.is_thd():
        kwargs.pop("max_segments_per_seq")
        return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
    return fused_attn_thd(
        qkv_args,
        bias,
        seqlens_q,
        seqlens_kv,
        offsets_q,
        offsets_kv,
        dropout_rng,
        **kwargs,
    ).astype(query.dtype)
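

# Added note (illustrative sketch): for the packed layouts, the separate q/k/v tensors of
# shape (b, s, h, d) receive an extra axis at -3 and are concatenated, e.g. BS3HD/T3HD packs
# q, k, v into a single qkv tensor of shape (b, s, 3, h, d), while BSHD_BS2HD/THD_T2HD packs
# only k and v into kv of shape (b, s, 2, h, d) and leaves q as (b, s, h, d).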


class BiasShape(Enum):
    """
    Enum class to represent the different bias shapes used in the fused attention.
    """

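    # Naming follows the bias dimensions (batch, heads, seqlen_q, seqlen_kv); a "1" marks a
    # broadcast dimension, e.g. _B1SS is a (b, 1, s_q, s_kv) bias and _11SS is (1, 1, s_q, s_kv).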
    _1HSS = "1HSS"
    _B1SS = "B1SS"
    _BHSS = "BHSS"
    _11SS = "11SS"


@dataclass
class FusedAttnRunner:
    """
    Fused attention runner
    """

    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    attn_bias_type: AttnBiasType
    attn_mask_type: AttnMaskType
    dropout_prob: float
    dtype: DTypeLike
    is_training: bool
    qkv_layout: QKVLayout
    bias_shape: BiasShape
    window_size: Optional[Tuple[int, int]] = None

    # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for a known issue
    # with generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
    def _get_max_segments_per_sequence(self):
        if 90400 <= get_cudnn_version() < 90500:
            return self.num_segments_per_seq
        else:
            # +1 for testing runtime_segments < max_segments
            return self.num_segments_per_seq + 1

    def _check_configs(self):
        # TODO(rewang): probably add this to is_fused_attn_available
        if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
            pytest.skip("THD format requires padding masks.")

        if self.qkv_layout.is_qkvpacked():
            if self.max_seqlen_q != self.max_seqlen_kv:
                pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")
            if self.num_heads_q != self.num_heads_kv:
                pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

        if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
            pytest.skip(
                "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
            )

        self.backend = FusedAttnHelper(
            self.dtype,
            self.dtype,
            self.qkv_layout.value,
            self.attn_bias_type.value,
            self.attn_mask_type.value,
            self.dropout_prob,
            self.num_heads_q,
            self.num_heads_kv,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
            (-1, -1) if self.window_size is None else self.window_size,
        ).get_fused_attn_backend()
        if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
            pytest.skip("Unsupported inputs combination or device compute capability.")

        if (
            self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and self.bias_shape != BiasShape._1HSS
        ):
            if self.attn_mask_type.is_padding():
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask"
                )
            elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
                pytest.skip(
                    "B1SS, BHSS and 11SS bias shapes are only supported for "
                    "the F16_arbitrary_seqlen backend."
                )

    def _setup_inputs(self):
        self._check_configs()
        key = jax.random.PRNGKey(0)
        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

        q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
        k_shape = v_shape = (
            self.batch_size,
            self.max_seqlen_kv,
            self.num_heads_kv,
            self.head_dim,
        )

        if self.attn_bias_type == AttnBiasType.NO_BIAS:
            bias_shape = None
        elif self.bias_shape == BiasShape._1HSS:
            bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._B1SS:
            bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        elif self.bias_shape == BiasShape._BHSS:
            bias_shape = (
                self.batch_size,
                self.num_heads_q,
                self.max_seqlen_q,
                self.max_seqlen_kv,
            )
        elif self.bias_shape == BiasShape._11SS:
            bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
        else:
            pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

        self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
        self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
        self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            if self.bias_shape == BiasShape._1HSS:
                self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
            else:
                # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
                # an arbitrary mask where (True/False -> 0/-Inf)
                cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
                self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
                max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
                seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
                seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
                for i in range(1, len(seq_id)):
                    self.bias = self.bias.at[
                        :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
                    ].set(0.0)
        else:
            self.bias = None

        if self.attn_mask_type.is_padding():
            pad_ratio = 0.3
        else:
            pad_ratio = 0.0

        def gen_valid(bs, max_seqlen, pad_ratio):
            pad_len = int(max_seqlen * pad_ratio)
            valid_len = max_seqlen - pad_len
            tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
            return tokens, jnp.logical_not(tokens)

        def generate_random_segment_ids(
            batch_size, sequence_length, num_segments, seed, with_segment_pad=True
        ):
            rng = np.random.default_rng(seed=seed)
            # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
            segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
            segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
            # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
            # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
            segment_pad = np.zeros((batch_size, sequence_length), dtype=int)

            # Does not include padding
            max_segment_size = sequence_length // num_segments
            for i in range(batch_size):
                current_pos = 0
                segment_id = 1

                for _ in range(num_segments):
                    segment_size = rng.integers(1, max_segment_size + 1)
                    if current_pos + segment_size > sequence_length:
                        break
                    segment_end = current_pos + segment_size
                    segment_ids[i, current_pos:segment_end] = segment_id
                    segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
                    if with_segment_pad:
                        num_valid = rng.integers(1, segment_size + 1)
                        segment_pad[i, current_pos + num_valid : segment_end] = 1
                    current_pos = segment_end
                    segment_id += 1
                segment_pad[i, current_pos:sequence_length] = 1

            segment_ids, segment_pos, segment_pad = map(
                jnp.asarray, [segment_ids, segment_pos, segment_pad]
            )
            segment_ids = jnp.where(segment_pad, 0, segment_ids)
            return segment_ids, segment_pos, segment_pad

        if self.qkv_layout.is_thd():
            self.num_segments_per_seq = 2
            self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
                self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
            )
            if self.qkv_layout == QKVLayout.T3HD:
                self.segment_ids_kv = self.segment_ids_q
                self.segment_pos_kv = self.segment_pos_q
                self.pad_kv = self.pad_q
            else:
                self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                    self.batch_size,
                    self.max_seqlen_kv,
                    self.num_segments_per_seq,
                    seed=2024,
                )
            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
            self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
        else:
            self.num_segments_per_seq = 1
            self.segment_ids_q, self.pad_q = gen_valid(
                self.batch_size, self.max_seqlen_q, pad_ratio
            )
            self.segment_ids_kv, self.pad_kv = gen_valid(
                self.batch_size, self.max_seqlen_kv, pad_ratio
            )
            self.segment_pos_q = self.segment_pos_kv = None
            self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

        # For reference code
        self.mask = make_mask(
            self.segment_ids_q,
            self.segment_ids_kv,
            self.segment_pos_q,
            self.segment_pos_kv,
            self.attn_mask_type,
            self.window_size,
        )

        if self.qkv_layout.is_thd():
            self.mask_for_customcall = None  # THD format doesn't support mask
        else:
            self.mask_for_customcall = self.mask

        self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
        self.scaling_factor = 1.0 / sqrt(self.head_dim)

    def test_forward(self):
        """
        Test forward without JIT
        """
        self._setup_inputs()

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # Convert the outputs to float32 for the elementwise comparison
        primitive_out = customcall_fused_dpa(*customcall_args, **kwargs)
        reference_out = jax_dpa(*args, **kwargs)

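        # Skip the elementwise comparison when dropout is enabled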
        if self.is_training and self.dropout_prob > 0.0:
            return

        primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
            _split_valid_and_invalid(primitive_out, reference_out, self.pad_q)
        )

        assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
        assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

    def test_backward(self):
        """
        Test value_and_grad with JIT, which includes both forward and backward
        """

        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            # Gradients are small, so use a gradient multiplier to amplify them
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only the valid results for the gradient
            ret_valid = jnp.where(
                self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
            )
            return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)

        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
        customcall_args = [
            self.q,
            self.k,
            self.v,
            self.bias,
            self.mask_for_customcall,
            self.seqlens_q,
            self.seqlens_kv,
            self.offsets_q,
            self.offsets_kv,
            self.dropout_rng,
        ]
        kwargs = {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
        }

        # We can compute dBias only for the [1, h, s, s] layout
        arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, **kwargs
                ),
                arg_nums,
            )
        )
        jitted_reference = jit(
            value_and_grad(
                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
                arg_nums,
            )
        )

        primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)
        reference_out, reference_dgrad = jitted_reference(*args)

        # Skip the elementwise comparison when dropout is enabled
        if self.dropout_prob > 0.0:
            return

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

        def check_dqkv(primitive, reference, pad):
            primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
                _split_valid_and_invalid(primitive, reference, pad)
            )

            assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
            assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
            assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

        primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3]
        reference_dq, reference_dk, reference_dv = reference_dgrad[:3]

        check_dqkv(primitive_dq, reference_dq, self.pad_q)
        check_dqkv(primitive_dk, reference_dk, self.pad_kv)
        check_dqkv(primitive_dv, reference_dv, self.pad_kv)

        if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
            primitive_dbias = primitive_dgrad[3]
            reference_dbias = reference_dgrad[3]

            # Assume all batches share the same actual_seqlen; the tests probably need to be extended
            bias_mask = self.mask[0, 0]

            # Assert that all masked dbias entries are zeros
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.zeros_like(primitive_dbias),
                dtype=self.dtype,
            )

            # dbias padded part
            assert_allclose(
                jnp.where(bias_mask, primitive_dbias, 0),
                jnp.where(bias_mask, reference_dbias, 0),
                dtype=self.dtype,
            )

            # dbias valid part
            assert_allclose(
                jnp.where(bias_mask, 0, primitive_dbias),
                jnp.where(bias_mask, 0, reference_dbias),
                dtype=self.dtype,
            )


@pytest.mark.parametrize(
    "attn_mask_type",
    [
        pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
        pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
        pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
        pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
    ],
)
@pytest.mark.parametrize(
    "qkv_layout",
    [
        pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
        pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
        pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
        pytest.param(QKVLayout.T3HD, id="RAGGED_QKV_PACKED"),
        pytest.param(QKVLayout.THD_T2HD, id="RAGGED_KV_PACKED"),
        pytest.param(QKVLayout.THD_THD_THD, id="RAGGED_SEPARATE"),
    ],
)
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d, dtype",
    [
        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
        pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
        pytest.param(
            2,
            2048,
            1024,
            12,
            12,
            64,
            jnp.bfloat16,
            id="2-2048-1024-12-12-64-BF16-CROSS",
        ),
        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
        pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
    ],
)
@pytest.mark.parametrize(
    "dropout_prob",
    [
        pytest.param(0.0, id="DROP_0.0"),
        pytest.param(0.1, id="DROP_0.1"),
    ],
)
@pytest.mark.parametrize(
    "swa",
    [
        pytest.param(False, id="NO_SWA"),
        pytest.param(True, id="SWA"),
    ],
)
class TestFusedAttn:
    """
    Fused attention tester
    """

    @staticmethod
    @pytest.mark.parametrize(
        "is_training",
        [
            pytest.param(True, id="TRAINING"),
            pytest.param(False, id="INFERENCE"),
        ],
    )
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"),
        ],
    )
    def _test_forward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        is_training,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test forward with parameterized configs
        This test is not intended to run automatically during CI, as it is time-consuming.
        It is kept for development and debugging.
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_forward()

    @staticmethod
    @pytest.mark.parametrize(
        "attn_bias_type, bias_shape",
        [
            pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
            pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
        ],
    )
    def test_backward(
        b,
        s_q,
        s_kv,
        h_q,
        h_kv,
        d,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        dtype,
        qkv_layout,
        bias_shape,
        swa,
    ):
        """
        Test backward with parameterized configs
        """
        window_size = None
        if swa:
            window_size = (s_kv // 10, 0)
        runner = FusedAttnRunner(
            b,
            s_q,
            s_kv,
            h_q,
            h_kv,
            d,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            dtype,
            True,
            qkv_layout,
            bias_shape,
            window_size,
        )
        runner.test_backward()