# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for fused attention"""

import os
from enum import Enum
from math import sqrt

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from flax.linen import combine_masks
from flax.linen import dot_product_attention
from flax.linen import make_attention_mask
from flax.linen import make_causal_mask
from jax import value_and_grad, jit

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability    # pylint: disable=wrong-import-order

# Type annotations
Array = jnp.ndarray


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep device memory clean between tests
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class Backend(Enum):
    """
    Fused attention backend.
    For unit tests only; the transformer auto-dispatches to the best backend.
    The enum values are what the backend fixture assigns to NVTE_FUSED_ATTN_BACKEND.
    """
    Max512 = "0"
    Arbitrary = "1"


@pytest.fixture(name="backend", params=[Backend.Max512, Backend.Arbitrary])
def fixture_backend(request):
    """
    Fixture for setting up/tearing down the backend
    """
    backend = request.param
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
    yield backend
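    # Clear the override so subsequent tests are not pinned to this backend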
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""


SELF_CASES = [(32, 512, 16, 64), (32, 128, 16, 64), (4, 2048, 12, 64)]
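# Case tuples: SELF_CASES entries are (batch, seqlen, num_heads, head_dim);
# CROSS_CASES entries are (batch, seqlen_q, seqlen_kv, num_heads, head_dim)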
CROSS_CASES = [(32, 128, 512, 16, 64)]
DTYPES = [jnp.bfloat16, jnp.float16]


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(tokens: Array) -> Array:
    """
    Create padded causal mask
    """
    causal_mask = make_causal_mask(tokens)
    padding_mask = make_attention_mask(tokens > 0, tokens > 0)
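    # The combined mask lets each query attend only to earlier, non-padded positions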
    return combine_masks(causal_mask, padding_mask)


def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    Self attention with JAX native implementation
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # qkv is packed as (batch, seqlen, 3, num_heads, head_dim); split along the
    # packing axis and squeeze it away to get separate query, key, value tensors
    query, key, value = jnp.split(qkv, [1, 2], axis=-3)
    query = jnp.squeeze(query)
    key = jnp.squeeze(key)
    value = jnp.squeeze(value)

    output = dot_product_attention(query,
                                   key,
                                   value,
                                   bias=bias,
                                   mask=mask,
                                   deterministic=not kwargs['is_training'],
                                   dropout_rate=kwargs['dropout_probability'],
                                   dropout_rng=dropout_rng,
                                   dtype=jnp.float32)
    return output.astype(qkv.dtype)


def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
    """
    Cross attention with JAX native implementation
    """
    assert q.dtype == kv.dtype

    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        raise NotImplementedError
    mask = make_attention_mask(q_token > 0, kv_token > 0)

    query = q
    # kv is packed as (batch, seqlen, 2, num_heads, head_dim); split out key and value
    key, value = jnp.split(kv, [1], axis=-3)
    key = jnp.squeeze(key)
    value = jnp.squeeze(value)

    output = dot_product_attention(query,
                                   key,
                                   value,
                                   bias=None,
                                   mask=mask,
                                   deterministic=not kwargs['is_training'],
                                   dropout_rate=kwargs['dropout_probability'],
                                   dropout_rng=dropout_rng,
                                   dtype=jnp.float32)
    return output.astype(q.dtype)


def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
    """
    Self fused attention
    """
    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        mask = make_decoder_mask(q_token)
    else:
        mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax masks are 1 where attention is allowed, while the
    # fused attention custom call expects 1 at masked-out positions
    mask = (mask == 0)

    return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs)


def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
    """
    Cross fused attention
    """
    assert q.dtype == kv.dtype

    attn_mask_type = kwargs['attn_mask_type']
    if is_causal_mask(attn_mask_type):
        raise NotImplementedError
    mask = make_attention_mask(q_token > 0, kv_token > 0)

    # Invert the mask: flax masks are 1 where attention is allowed, while the
    # fused attention custom call expects 1 at masked-out positions
    mask = (mask == 0)

    return cross_fused_attn(q, kv, None, mask, dropout_rng, **kwargs)


@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [
    AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
    AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
class TestSelfFusedAttn():
    """Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

    @staticmethod
    def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
                      head_dim):

        assert isinstance(backend, Backend)

        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
                                              attn_mask_type, dropout_probability, s, s, head_dim):
            pytest.skip("Unsupported input combination or device compute capability.")

    def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
                    dropout_probability, dtype, is_training):
        """Setup the test inputs"""
        self.__class__._check_inputs(s,
                                     attn_bias_type=attn_bias_type,
                                     attn_mask_type=attn_mask_type,
                                     backend=backend,
                                     dropout_probability=dropout_probability,
                                     dtype=dtype,
                                     head_dim=d)

        if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
            pad_ratio = 0.0
        else:
            pad_ratio = 0.3

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        qkv_shape = (b, s, 3, h, d)
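        # Bias has shape (1, num_heads, seqlen, seqlen) so it broadcasts across the batch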
        bias_shape = (1, h, s, s)

        pad_len = int(s * pad_ratio)
        self.valid_len = s - pad_len

        min_val, max_val = -1, 1
        self.qkv = jax.random.uniform(subkeys[0], qkv_shape, dtype, min_val, max_val)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        self.bias = jax.random.uniform(subkeys[1], bias_shape, dtype, min_val,
                                       max_val) if with_bias else None

        # Token arrays mark valid positions with 1 and padded positions with 0
        self.q_token = jnp.concatenate((jnp.ones((b, self.valid_len)), jnp.zeros((b, pad_len))),
                                       axis=-1)
        self.kv_token = self.q_token

        self.scaling_factor = 1. / sqrt(d)
        self.dropout_probability = dropout_probability
        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
        self.attn_bias_type = attn_bias_type
        self.attn_mask_type = attn_mask_type
        self.is_training = is_training

    def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
                     dtype, is_training):
        """
        Test forward without using JIT
        """
        self._set_inputs(b,
                         s,
                         h,
                         d,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         backend=backend,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training)

        primitive_out = customcall_self_fused_attn(self.qkv,
                                                   self.bias,
                                                   self.q_token,
                                                   self.kv_token,
                                                   self.dropout_rng,
                                                   attn_bias_type=self.attn_bias_type,
                                                   attn_mask_type=attn_mask_type,
                                                   scaling_factor=self.scaling_factor,
                                                   dropout_probability=self.dropout_probability,
                                                   is_training=self.is_training)

        reference_out = jax_self_attn(self.qkv,
                                      self.bias,
                                      self.q_token,
                                      self.kv_token,
                                      self.dropout_rng,
                                      attn_mask_type=attn_mask_type,
                                      scaling_factor=self.scaling_factor,
                                      dropout_probability=self.dropout_probability,
                                      is_training=self.is_training)

        ref_valid, _ = jnp.split(reference_out, (self.valid_len,), axis=1)
        pri_valid, pri_invalid = jnp.split(primitive_out, (self.valid_len,), axis=1)

        # Dropout can't produce bit-exact results, so skip the elementwise comparison
        if is_training and dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
                                   jnp.asarray(ref_valid, np.float32),
                                   rtol=1e-4,
                                   atol=1e-2)

        np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
                                   jnp.zeros_like(pri_invalid, jnp.float32))

    def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
                              dropout_probability, dtype, is_training):
        """
        Test forward, backward, and autodiff by jax.value_and_grad
        """
        if not is_training:
            pytest.skip(f"Backward doesn't support {is_training=}")

        self._set_inputs(b,
                         s,
                         h,
                         d,
                         attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         backend=backend,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training)

        def grad_func(fused_attn_func, *args, **kwargs):
            # Gradients are small; use a multiplier to amplify the gradient
            gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
            if is_causal_mask(attn_mask_type):
                gradient_multiplier = gradient_multiplier / 10
            # Keep only the valid part of the output for the gradient
            # fused_attn output has shape (b, s, h, d)
            valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs), (self.valid_len,),
                                                axis=1)
            return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
                    gradient_multiplier).astype(dtype)

        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_probability,
            'is_training': self.is_training
        }

        # Summing in FP16/BF16 may overflow, so do the reduction in FP32
        jitted_primitive = jit(
            value_and_grad(
                lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
                    customcall_self_fused_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs
                ), (0, 1)))

        jitted_reference = jit(
            value_and_grad(
                lambda qkv, bias, q_token, kv_token, dropout_rng: grad_func(
                    jax_self_attn, qkv, bias, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))

        primitive_out, (primitive_dqkv,
                        primitive_dbias) = jitted_primitive(self.qkv, self.bias, self.q_token,
                                                            self.kv_token, self.dropout_rng)

        reference_out, (reference_dqkv,
                        reference_dbias) = jitted_reference(self.qkv, self.bias, self.q_token,
                                                            self.kv_token, self.dropout_rng)

        # Dropout can't produce bit-exact results, so skip the elementwise comparison
        if dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
                                   jnp.asarray(reference_out, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        valid_primitive_dqkv, invalid_primitive_dqkv = \
            jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
        valid_reference_dqkv, invalid_reference_dqkv = \
            jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)

        valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
            jnp.split(valid_primitive_dqkv, 3, axis=2)
        valid_reference_dq, valid_reference_dk, valid_reference_dv = \
            jnp.split(valid_reference_dqkv, 3, axis=2)

        np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
        np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
        np.testing.assert_allclose(valid_primitive_dv, valid_reference_dv, rtol=1e-4, atol=1e-5)

        assert jnp.allclose(invalid_primitive_dqkv, invalid_reference_dqkv)

        # The padded part should be zeros
        assert jnp.allclose(invalid_primitive_dqkv, jnp.zeros_like(invalid_primitive_dqkv))

        if self.attn_bias_type != AttnBiasType.NO_BIAS:
            # dbias valid part
            np.testing.assert_allclose(
                jnp.asarray(primitive_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
                jnp.asarray(reference_dbias[:, :, :self.valid_len, :self.valid_len], np.float32),
                rtol=1e-4,
                atol=3e-5)

            # dbias padded part
            np.testing.assert_allclose(
                jnp.asarray(primitive_dbias[:, :, self.valid_len:, self.valid_len:], np.float32),
                jnp.asarray(reference_dbias[:, :, self.valid_len:, self.valid_len:], np.float32))

            assert jnp.allclose(
                primitive_dbias[:, :, self.valid_len:, self.valid_len:],
                jnp.zeros_like(primitive_dbias[:, :, self.valid_len:, self.valid_len:]))


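# Fused attention kernels are only available on Ampere (sm80) and Hopper (sm90) GPUs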
@pytest.mark.skipif(get_device_compute_capability(0) not in [80, 90],
                    reason="Fused attention kernel is not supported.")
@pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_CASES)
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0.3])
class TestCrossFusedAttn():
    """Tests for transformer_engine.jax.fused_attn.cross_fused_attn"""

    def _set_inputs(self, b, s_q, s_kv, h, d, *, attn_mask_type, dropout_probability, dtype,
                    is_training, pad_ratio):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)

        q_shape = (b, s_q, h, d)
        kv_shape = (b, s_kv, 2, h, d)
        q_pad_len = int(s_q * pad_ratio)
        kv_pad_len = int(s_kv * pad_ratio)
        self.q_valid_len = s_q - q_pad_len
        self.kv_valid_len = s_kv - kv_pad_len

        min_val, max_val = -1, 1
        self.q = jax.random.uniform(subkeys[0], q_shape, dtype, min_val, max_val)
        self.kv = jax.random.uniform(subkeys[1], kv_shape, dtype, min_val, max_val)

        self.q_token = jnp.concatenate((jnp.ones((b, self.q_valid_len)), jnp.zeros((b, q_pad_len))),
                                       axis=-1)
        self.kv_token = jnp.concatenate((jnp.ones((b, self.kv_valid_len)), jnp.zeros(
            (b, kv_pad_len))),
                                        axis=-1)
        self.scaling_factor = 1. / sqrt(d)
        self.dropout_probability = dropout_probability
        self.dropout_rng = jax.random.PRNGKey(0) if self.dropout_probability > 0 else None
        self.attn_bias_type = AttnBiasType.NO_BIAS
        self.attn_mask_type = attn_mask_type
        self.is_training = is_training

    def test_forward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
                     is_training, pad_ratio):
        """
        Test forward without using JIT
        """
        self._set_inputs(b,
                         s_q,
                         s_kv,
                         h,
                         d,
                         attn_mask_type=attn_mask_type,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        primitive_out = customcall_cross_fused_attn(self.q,
                                                    self.kv,
                                                    self.q_token,
                                                    self.kv_token,
                                                    self.dropout_rng,
                                                    attn_bias_type=self.attn_bias_type,
                                                    attn_mask_type=attn_mask_type,
                                                    scaling_factor=self.scaling_factor,
                                                    dropout_probability=self.dropout_probability,
                                                    is_training=self.is_training)

        reference_out = jax_cross_attn(self.q,
                                       self.kv,
                                       self.q_token,
                                       self.kv_token,
                                       self.dropout_rng,
                                       attn_mask_type=attn_mask_type,
                                       scaling_factor=self.scaling_factor,
                                       dropout_probability=self.dropout_probability,
                                       is_training=self.is_training)

        # Dropout can't produce bit-exact results, so skip the elementwise comparison
        if is_training and dropout_probability > 0.:
            return

        ref_valid, _ = jnp.split(reference_out, (self.q_valid_len,), axis=1)
        pri_valid, pri_invalid = jnp.split(primitive_out, (self.q_valid_len,), axis=1)

        np.testing.assert_allclose(jnp.asarray(pri_valid, np.float32),
                                   jnp.asarray(ref_valid, np.float32),
                                   rtol=1e-4,
                                   atol=2e-3)

        np.testing.assert_allclose(jnp.asarray(pri_invalid, jnp.float32),
                                   jnp.zeros_like(pri_invalid, jnp.float32))

    def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_probability, dtype,
                              is_training, pad_ratio):
        """
        Test forward, backward, and autodiff by jax.value_and_grad
        """
        if not is_training:
            pytest.skip(f"Backward doesn't support {is_training=}")

        self._set_inputs(b,
                         s_q,
                         s_kv,
                         h,
                         d,
                         attn_mask_type=attn_mask_type,
                         dropout_probability=dropout_probability,
                         dtype=dtype,
                         is_training=is_training,
                         pad_ratio=pad_ratio)

        def grad_func(fused_attn_func, *args, **kwargs):
            # Gradients are small; use a multiplier to amplify the gradient
            gradient_multiplier = 1e4
            # Keep only the valid part of the output for the gradient
            # fused_attn output has shape (b, s_q, h, d)
            valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
                                                (self.q_valid_len,),
                                                axis=1)
            return (jnp.mean(valid_fused_attn_ret, dtype=jnp.float32) *
                    gradient_multiplier).astype(dtype)

        kwargs = {
            'attn_bias_type': self.attn_bias_type,
            'attn_mask_type': attn_mask_type,
            'scaling_factor': self.scaling_factor,
            'dropout_probability': self.dropout_probability,
            'is_training': self.is_training
        }

        # Summing in FP16/BF16 may overflow, so do the reduction in FP32
        jitted_primitive = jit(
            value_and_grad(
                lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
                    customcall_cross_fused_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs),
                (0, 1)))

        jitted_reference = jit(
            value_and_grad(
                lambda q, kv, q_token, kv_token, dropout_rng: grad_func(
                    jax_cross_attn, q, kv, q_token, kv_token, dropout_rng, **kwargs), (0, 1)))

        primitive_out, (primitive_dq,
                        primitive_dkv) = jitted_primitive(self.q, self.kv, self.q_token,
                                                          self.kv_token, self.dropout_rng)

        reference_out, (reference_dq,
                        reference_dkv) = jitted_reference(self.q, self.kv, self.q_token,
                                                          self.kv_token, self.dropout_rng)

        # Dropout can't produce bit-exact results, so skip the elementwise comparison
        if dropout_probability > 0.:
            return

        np.testing.assert_allclose(jnp.asarray(primitive_out, np.float32),
                                   jnp.asarray(reference_out, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        valid_primitive_dq, invalid_primitive_dq = jnp.split(primitive_dq, (self.q_valid_len,),
                                                             axis=1)
        valid_reference_dq, invalid_reference_dq = jnp.split(reference_dq, (self.q_valid_len,),
                                                             axis=1)

        valid_primitive_dkv, invalid_primitive_dkv = jnp.split(primitive_dkv, (self.kv_valid_len,),
                                                               axis=1)
        valid_reference_dkv, invalid_reference_dkv = jnp.split(reference_dkv, (self.kv_valid_len,),
                                                               axis=1)

        # dQ
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dq, np.float32),
                                   jnp.asarray(valid_reference_dq, np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        # dK
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 0], np.float32),
                                   jnp.asarray(valid_reference_dkv[:, :, 0], np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        # dV
        np.testing.assert_allclose(jnp.asarray(valid_primitive_dkv[:, :, 1], np.float32),
                                   jnp.asarray(valid_reference_dkv[:, :, 1], np.float32),
                                   rtol=1e-4,
                                   atol=1e-5)

        assert jnp.allclose(invalid_primitive_dq, invalid_reference_dq)
        assert jnp.allclose(invalid_primitive_dkv, invalid_reference_dkv)

        # The padded part should be zeros
        assert jnp.allclose(invalid_primitive_dq, jnp.zeros_like(invalid_primitive_dq))
        assert jnp.allclose(invalid_primitive_dkv, jnp.zeros_like(invalid_primitive_dkv))