# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""

import os
from enum import Enum
from math import sqrt

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask
from flax.linen import make_causal_mask
from jax import value_and_grad, jit

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability

# Type annotations
Array = jnp.ndarray


class Backend(Enum):
    """
    Fused attn backend.
    For unit tests only; the transformer auto-dispatches to the best backend otherwise
    """
    Max512 = "0"
    Arbitrary = "1"


@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary])
def fixture_backend(request):
    """
    Fixture of setting up/tearing down backend
    """
    backend = request.param
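    # NVTE_FUSED_ATTN_BACKEND forces the selected fused attention backend for this test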
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
    yield backend
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""


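# Test case shapes: SELF_CASES is (batch, seqlen, num_heads, head_dim),
# CROSS_CASES is (batch, q_seqlen, kv_seqlen, num_heads, head_dim)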
SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)]
CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]


def make_decoder_mask(tokens: Array) -> Array:
    """
    Create a padded causal mask
    """
    causal_mask = make_causal_mask(tokens)
    padding_mask = make_attention_mask(tokens > 0, tokens > 0)
    return combine_masks(causal_mask, padding_mask)


def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    Self attention with JAX native implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if attn_mask_type == AttnMaskType.CAUSAL_MASK:
        mask = make_decoder_mask(q_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

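    # qkv is packed as (b, s, 3, h, d); split and squeeze into q, k, v of shape (b, s, h, d)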
    query, key, value = jnp.split(qkv, [1, 2], axis=-3)
    query = jnp.squeeze(query)
    key = jnp.squeeze(key)
    value = jnp.squeeze(value)

    output = dot_product_attention(query,
                                   key,
                                   value,
                                   bias=bias,
                                   mask=mask,
                                   deterministic=not kwargs['is_training'],
                                   dropout_rate=kwargs['dropout_probability'],
                                   dropout_rng=dropout_rng,
                                   dtype=qkv.dtype)
    return output


def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
    """
    Cross attention with JAX native implementation
    """
    assert q.dtype == kv.dtype

    attn_mask_type = kwargs['attn_mask_type']
    if attn_mask_type == AttnMaskType.CAUSAL_MASK:
        raise NotImplementedError
    mask = make_attention_mask(q_token > 0, kv_token > 0)

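    # kv is packed as (b, s_kv, 2, h, d); split and squeeze into key and value of shape (b, s_kv, h, d)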
    query = q
    key, value = jnp.split(kv, [1], axis=-3)
    key = jnp.squeeze(key)
    value = jnp.squeeze(value)

    output = dot_product_attention(query,
                                   key,
                                   value,
                                   bias=None,
                                   mask=mask,
                                   deterministic=not kwargs['is_training'],
                                   dropout_rate=kwargs['dropout_probability'],
                                   dropout_rng=dropout_rng,
                                   dtype=q.dtype)
    return output


def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    Self fused attention
    """
    if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
        mask = make_decoder_mask(q_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax-style masks use 1 for positions to attend, while the fused attention custom call expects 1 for masked-out positions
    mask = (mask == 0)

    return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs)


def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
    """
    Cross fused attention
    """
    assert q.dtype == kv.dtype

    if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
        raise NotImplementedError
    mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax-style masks use 1 for positions to attend, while the fused attention custom call expects 1 for masked-out positions
    mask = (mask == 0)

    return cross_fused_attn(q, kv, mask, dropout_rng, **kwargs)


@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
    """Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

    @staticmethod
    def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
                      head_dim, pad_ratio):
        if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
            pytest.skip("Arbitrary seqlen backend hasn't support padded input.")

        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
                                              attn_mask_type, dropout_probability, s, s, head_dim):
            pytest.skip("Unsupported inputs combination or device compute capability.")

        compute_capability = get_device_compute_capability(0)
        if (backend == Backend.Max512
                and not (compute_capability == 80 or compute_capability >= 90)):
            pytest.skip("Unsupported compute capability for "
                        "fused attention with <=512 sequence length")

    def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
                    dropout_probability, dtype, is_training, pad_ratio):
        """Setup the test inputs"""
        self.__class__._check_inputs(s,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     backend=backend,
                                     dropout_probability=dropout_probability,
                                     dtype=dtype,
                                     head_dim=d,
                                     pad_ratio=pad_ratio)
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        qkv_shape = (b, s, 3, h, d)
        bias_shape = (1, h, s, s)

        pad_len = int(s * pad_ratio)
        self.valid_len = s - pad_len

        min_val, max_val = -1, 1
        self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
                                       max_val) if with_bias else None

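        # Token masks along the sequence axis: 1 for valid tokens, 0 for padding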
        self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))),
                                       axis=-1)
        self.kv_token = self.q_token

        self.scaling_factor = 1. / sqrt(d)
        self.dropout_probability = dropout_probability
        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
        self.attn_bias_type = attn_bias_type
        self.attn_mask_type = attn_mask_type
        self.is_training = is_training

    def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
                     dtype, is_training, pad_ratio):
        """
        Test forward without using JIT
        """
        self._set_inputs(b,
                         s,
                         h,
                         d,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         backend=backend,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        primitive_out = customcall_self_fused_attn(self.qkv,
                                                   self.bias,
                                                   self.q_token,
                                                   self.kv_token,
                                                   self.dropout_rng,
                                                   attn_bias_type=self.attn_bias_type,
                                                   attn_mask_type=attn_mask_type,
                                                   scaling_factor=self.scaling_factor,
                                                   dropout_probability=self.dropout_probability,
                                                   is_training=self.is_training)

        reference_out = jax_self_attn(self.qkv,
                                      self.bias,
                                      self.q_token,
                                      self.kv_token,
                                      self.dropout_rng,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=self.scaling_factor,
                                      dropout_probability=self.dropout_probability,
                                      is_training=self.is_training)

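        # Split outputs into the valid region and the padded region along the sequence axis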
        ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
        pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)

        # Dropout does not produce bit-exact results, so skip the elementwise comparison
        if is_training and dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
                                   jnp.asarray(ref_valid, np.float32),
                                   rtol=1e-4,
                                   atol=1e-2)

        np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
                                   jnp.zeros_like(pri_invalid, jnp.float32))

    def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
                              dropout_probability, dtype, is_training, pad_ratio):
        """
        Test forward, backward, and autodiff via jax.value_and_grad
        """
        if not is_training:
            pytest.skip(f"Backward doesn't support {is_training=}")

        self._set_inputs(b,
                         s,
                         h,
                         d,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         backend=backend,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        def grad_func(fused_attn_func, *args, **kwargs):
            # The gradient is small, so use a gradient multiplier to amplify it
            gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
            if attn_mask_type == AttnMaskType.CAUSAL_MASK:
                gradient_multiplier = gradient_multiplier / 10
            # Keep only the valid (non-padded) part of the output for the gradient
            # fused_attn output has shape (b, s, h, d)
            valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,),
                                                axis=1)
            return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
                    gradient_multiplier).astype(dtype)

        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_probability,
            'is_training': self.is_training
        }

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
                    customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
                ), (0, 1)))

        jitted_reference = jit(
            value_and_grad(
                lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
                    jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))

        primitive_out, (primitive_dqkv,
                        primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
                                                            self.kv_token, self.dropout_rng)

        reference_out, (reference_dqkv,
                        reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
                                                            self.kv_token, self.dropout_rng)

        # Dropout does not produce bit-exact results, so skip the elementwise comparison
        if dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
                                   jnp.asarray(reference_out, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

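        # Split gradients into valid and padded regions along the sequence axis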
        valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
                                                                 axis=1)
        valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
                                                                 axis=1)

        valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
            valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
        valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
            valid_reference_dqkv.astype(jnp.float32), 3, axis=2)

        np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
        np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
        np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5)

        assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)

        # Padded part should be 0s
        assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv))

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            # dbias valid part
            np.testing.assert_allclose(
                jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
                jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
                rtol=1e-4,
                atol=3e-5)

            # dbias padded part
            np.testing.assert_allclose(
                jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32),
                jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32))

            assert jnp.allclose(
                primitive_dbias[:, :, self.valid_len:, self.valid_len:],
                jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))


@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
                    reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
    """Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""

    def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
                    is_training, pad_ratio):
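        """Set up the test inputs"""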
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        q_shape = (b, s_q, h, d)
        kv_shape = (b, s_kv, 2, h, d)
        q_pad_len = int(s_q * pad_ratio)
        kv_pad_len = int(s_kv * pad_ratio)
        self.q_valid_len = s_q - q_pad_len
        self.kv_valid_len = s_kv - kv_pad_len

        min_val, max_val = -1, 1
        self.q = jax.random.uniform(subkeys[0], q_shape, dtype, min_val, max_val)
        self.kv = jax.random.uniform(subkeys[1], kv_shape, dtype, min_val, max_val)

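        # Token masks along the sequence axes: 1 for valid tokens, 0 for padding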
        self.q_token = jnp.concatenate((jnp.ones((b, self.q_valid_len)), jnp.zeros((b, q_pad_len))),
                                       axis=-1)
        self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
            (b, kv_pad_len))),
                                        axis=-1)
        self.scaling_factor = 1. / sqrt(d)
        self.dropout_probability = dropout_probability
        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
        self.attn_bias_type = AttnBiasType.NO_BIAS
        self.attn_mask_type = attn_mask_type
        self.is_training = is_training

    def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
                     is_training, pad_ratio):
        """
        Test forward without using JIT
        """
        self._set_inputs(b,
                         s_q,
                         s_kv,
                         h,
                         d,
                         attn_mask_type=attn_mask_type,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        primitive_out = customcall_cross_fused_attn(self.q,
                                                    self.kv,
                                                    self.q_token,
                                                    self.kv_token,
                                                    self.dropout_rng,
                                                    attn_bias_type=self.attn_bias_type,
                                                    attn_mask_type=attn_mask_type,
                                                    scaling_factor=self.scaling_factor,
                                                    dropout_probability=self.dropout_probability,
                                                    is_training=self.is_training)

        reference_out = jax_cross_attn(self.q,
                                       self.kv,
                                       self.q_token,
                                       self.kv_token,
                                       self.dropout_rng,
                                       attn_mask_type=attn_mask_type,
                                       scaling_factor=self.scaling_factor,
                                       dropout_probability=self.dropout_probability,
                                       is_training=self.is_training)

        # Dropout does not produce bit-exact results, so skip the elementwise comparison
        if is_training and dropout_probability > 0.:
            return

        ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
        pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)

        np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
                                   jnp.asarray(ref_valid, np.float32),
                                   rtol=1e-4,
                                   atol=2e-3)

        np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
                                   jnp.zeros_like(pri_invalid, jnp.float32))

    def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
                              is_training, pad_ratio):
        """
        Test forward, backward, and autodiff via jax.value_and_grad
        """
        if not is_training:
            pytest.skip(f"Backward doesn't support {is_training=}")

        self._set_inputs(b,
                         s_q,
                         s_kv,
                         h,
                         d,
                         attn_mask_type=attn_mask_type,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        def grad_func(fused_attn_func, *args, **kwargs):
            # The gradient is small, so use a gradient multiplier to amplify it
            gradient_multiplier = 10000
            if attn_mask_type == AttnMaskType.CAUSAL_MASK:
                gradient_multiplier = gradient_multiplier / 10
            # Keep only the valid (non-padded) part of the output for the gradient
            # fused_attn output has shape (b, s_q, h, d)
            valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
                                                (self.q_valid_len,),
                                                axis=1)
            return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
                    gradient_multiplier).astype(dtype)

        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_probability,
            'is_training': self.is_training
        }

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
        jitted_primitive = jit(
            value_and_grad(
                lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
                    customcall_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs),
                (0, 1)))

        jitted_reference = jit(
            value_and_grad(
                lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
                    jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))

        primitive_out, (primitive_dq,
                        primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
                                                          self.kv_token, self.dropout_rng)

        reference_out, (reference_dq,
                        reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
                                                          self.kv_token, self.dropout_rng)

        # Dropout does not produce bit-exact results, so skip the elementwise comparison
        if dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
                                   jnp.asarray(reference_out, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        valid_primitive_dq, invalid_primitive_dq = jnp.split(primitive_dq, (self.q_valid_len,),
                                                             axis=1)
        valid_reference_dq, invalid_reference_dq = jnp.split(reference_dq, (self.q_valid_len,),
                                                             axis=1)

        valid_primitive_dkv, invalid_primitive_dkv = jnp.split(primitive_dkv, (self.kv_valid_len,),
                                                               axis=1)
        valid_reference_dkv, invalid_reference_dkv = jnp.split(reference_dkv, (self.kv_valid_len,),
                                                               axis=1)

        # dQ
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dq, np.float32),
                                   jnp.asarray(valid_reference_dq, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        # dK
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 0], np.float32),
                                   jnp.asarray(valid_reference_dkv[:, :, 0], np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        # dV
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 1], np.float32),
                                   jnp.asarray(valid_reference_dkv[:, :, 1], np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        assert jnp.allclose(invalid_primitive_dq, invalid_reference_dq)
        assert jnp.allclose(invalid_primitive_dkv, invalid_reference_dkv)

        # Padded part should be 0s
        assert jnp.allclose(invalid_primitive_dq, jnp.zeros_like(invalid_primitive_dq))
        assert jnp.allclose(invalid_primitive_dkv, jnp.zeros_like(invalid_primitive_dkv))