# test_mamba_ssm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

9
10
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
11
from vllm.attention.backends.utils import PAD_SLOT_ID
12
13
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
14
from vllm.platforms import current_platform
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104


def selective_state_update_ref(state,
                               x,
                               dt,
                               A,
                               B,
                               C,
                               D=None,
                               z=None,
                               dt_bias=None,
                               dt_softplus=False):
    """Pure-PyTorch reference for one selective state-update step.

    Inputs without a heads axis are lifted to the heads layout (one head,
    one group) before the computation, and the result is squeezed back at
    the end. ``state`` is overwritten in place with the updated state.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
        dt_softplus: apply softplus to ``dt`` (after adding ``dt_bias``)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    multi_head = state.dim() > 3

    # Normalise every argument to the layout that carries a heads/groups axis.
    # `unsqueeze` returns a view, so the in-place write below still reaches
    # the caller's `state` tensor.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    x = x.unsqueeze(1) if x.dim() == 2 else x
    dt = dt.unsqueeze(1) if dt.dim() == 2 else dt
    A = A.unsqueeze(0) if A.dim() == 2 else A
    B = B.unsqueeze(1) if B.dim() == 2 else B
    C = C.unsqueeze(1) if C.dim() == 2 else C
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    # Shape validation against the (normalised) state.
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape

    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)

    # (batch, nheads, dim, 1): the time step broadcast over the state axis.
    dt_col = rearrange(dt, "b h d -> b h d 1")
    decay = torch.exp(dt_col * A)  # (batch, nheads, dim, dstate)

    # Every group of B/C is shared by `heads_per_group` consecutive heads.
    heads_per_group = nheads // ngroups
    B = repeat(B, "b g n -> b (g h) n",
               h=heads_per_group)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n",
               h=heads_per_group)  # (batch, nheads, dstate)
    drive = dt_col * rearrange(
        B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)

    # In-place state update: state <- state * dA + dB * x.
    x_col = rearrange(x, "b h d -> b h d 1")
    state.copy_(state * decay + drive * x_col)

    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(x.dtype)
    return out if multi_head else out.squeeze(1)


def selective_scan_ref(u,
                       delta,
                       A,
                       B,
                       C,
                       D=None,
                       z=None,
                       delta_bias=None,
                       delta_softplus=False,
                       return_last_state=False,
                       prev_state=None,
                       final_state_out=None):
    """Sequential PyTorch reference implementation of the selective scan.

    The recurrence ``x <- exp(delta * A) * x + delta * B * u`` is stepped one
    position at a time along the last axis of ``u`` and projected through
    ``C`` at every step.

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    prev_state: r(B D N), fp32; initial state, zeros when omitted
    final_state_out: optional tensor that receives the last state in place

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)

    Returns ``out`` alone, or ``(out, last_state)`` when
    ``return_last_state`` is true. ``last_state`` is ``final_state_out``
    itself when that buffer was supplied.
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    # 2-D B/C are constant weights; >= 3-D B/C vary along the sequence.
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    B = B.float()
    C = C.float()
    x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
    ys = []
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
        else:
            # Grouped B: expand each group over the channels that share it.
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
            else:
                y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    # After the loop `x` holds the state of the final step; hand it back
    # either as-is or by filling the caller-provided buffer in place.
    if final_state_out is None:
        final_state_out = x
    else:
        final_state_out.copy_(x)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, final_state_out)
167
168


169
170
171
172
173
174
175
176
177
def selective_scan_opcheck_fn(u,
                              delta,
                              A,
                              B,
                              C,
                              D=None,
                              z=None,
                              delta_bias=None,
                              delta_softplus=False,
                              cu_seq_len=None,
                              cache_indices=None,
                              has_initial_state=None,
                              ssm_states=None,
                              pad_slot_id=PAD_SLOT_ID):
    """Run ``opcheck`` against the ``selective_scan_fwd`` custom op.

    The tensor arguments are first made contiguous where their last stride
    is not 1 (``D`` is always made contiguous), and ``B``/``C`` get an extra
    axis inserted: axis 1 for 3-D inputs when ``cu_seq_len`` is ``None``,
    axis 0 for 2-D inputs when ``cu_seq_len`` is given. The prepared
    arguments are then passed to ``opcheck`` for
    ``torch.ops._C.selective_scan_fwd``.

    Nothing is returned; the function only performs the check.
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3 and cu_seq_len is None:
        B = B.unsqueeze(1)
    if B.dim() == 2 and cu_seq_len is not None:
        B = B.unsqueeze(0)
    if C.dim() == 3 and cu_seq_len is None:
        C = C.unsqueeze(1)
    if C.dim() == 2 and cu_seq_len is not None:
        C = C.unsqueeze(0)

    # Disable test_autograd_registration for now as it seems to trigger
    # a bogus error.
    opcheck(torch.ops._C.selective_scan_fwd,
            (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
             cache_indices, has_initial_state, ssm_states, pad_slot_id),
            test_utils=["test_schema", "test_faketensor"])


215
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype',
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
                        has_z, has_delta_bias, delta_softplus, seqlen, itype,
                        wtype, scan_chunks):
    """Check chunked ``selective_scan_fn`` against ``selective_scan_ref``.

    The kernel is run over ``scan_chunks`` consecutive slices of the
    sequence, carrying ``state`` between calls (``has_initial_state`` is set
    for every chunk after the first). The concatenated output and the final
    state are compared with a single reference pass over the full sequence,
    and the op is finally put through ``selective_scan_opcheck_fn``.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = 'cuda'
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # set seed
    current_platform.seed_everything(0)
    batch_size = 1
    dim = 4
    dstate = 8
    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
    A_ref = A.clone()
    if not is_variable_B:
        B_shape = [dim, dstate]
    elif varBC_groups == 1:
        B_shape = [batch_size, dstate, seqlen]
    else:
        B_shape = [batch_size, varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape,
                    device=device,
                    dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    if not is_variable_C:
        C_shape = [dim, dstate]
    elif varBC_groups == 1:
        C_shape = [batch_size, dstate, seqlen]
    else:
        C_shape = [batch_size, varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape,
                    device=device,
                    dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    # Guarded so that a has_D=False parametrization would not crash here.
    D_ref = D.clone() if has_D else None
    z = torch.randn(batch_size, dim, seqlen, device=device,
                    dtype=itype) if has_z else None
    z_ref = z.clone() if has_z else None
    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
                  ) if has_delta_bias else None
    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = (0.5 *
             torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
    delta_ref = delta.clone()
    state_shape = (batch_size, u.shape[1], int(A.shape[1]))
    state = torch.randn(state_shape,
                        device=u.device,
                        dtype=itype,
                        requires_grad=False)

    out = None
    outs = []
    chunked_prompt_len = seqlen // scan_chunks
    for c in range(scan_chunks):
        chunk_start = chunked_prompt_len * c
        chunk_end = chunked_prompt_len * (c + 1)
        if c == scan_chunks - 1:
            # The last chunk absorbs the remainder of the sequence.
            chunk_end = seqlen
        # Only sequence-varying operands carry a length axis to slice.
        _B = B[..., chunk_start:chunk_end] if is_variable_B else B
        _C = C[..., chunk_start:chunk_end] if is_variable_C else C
        _z = z
        if has_z:
            assert z is not None
            _z = z[..., chunk_start:chunk_end]
        out = selective_scan_fn(
            u[..., chunk_start:chunk_end],
            state,
            delta[..., chunk_start:chunk_end],
            A,
            _B,
            _C,
            D,
            z=_z,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            has_initial_state=torch.ones(batch_size,
                                         device=u.device,
                                         dtype=torch.bool) if c > 0 else None)
        outs.append(out)
    if len(outs) > 1:
        out = torch.cat(outs, dim=-1)

    out_ref, state_ref = selective_scan_ref(
        u_ref,
        delta_ref,
        A_ref,
        B_ref,
        C_ref,
        D_ref,
        z=z_ref,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        return_last_state=True)

    assert out is not None and out_ref is not None
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    assert state is not None and state_ref is not None
    assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(u,
                              delta,
                              A,
                              B,
                              C,
                              D,
                              z,
                              delta_bias=delta_bias,
                              delta_softplus=delta_softplus,
                              ssm_states=state)
349

350
351
352
353
354
355
356
357
358
359
360
361
362
363

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    """Check ``selective_state_update`` against the PyTorch reference.

    Both implementations start from the same state (the reference works on
    a clone), and both the updated state and the output are compared within
    dtype-dependent tolerances.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            atol *= 2
    # set seed
    current_platform.seed_everything(0)
    batch_size = 1
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # Both implementations update their state argument in place, so the
    # reference gets its own copy.
    state_ref = state.detach().clone()
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True)
    out_ref = selective_state_update_ref(state_ref,
                                         x,
                                         dt,
                                         A,
                                         B,
                                         C,
                                         D=D,
                                         z=z,
                                         dt_bias=dt_bias,
                                         dt_softplus=True)

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
399
400


401
402
403
404
405
406
407
408
409
410
411
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
                               varBC_groups, has_D, has_z, has_delta_bias,
                               delta_softplus, return_last_state, seqlen,
                               itype, wtype):
    """Check variable-length ``selective_scan_fn`` against the reference.

    A single packed sequence of ``seqlen`` tokens is cut at random positions
    into ``padded_batch_size`` sub-sequences. The kernel processes the packed
    input in one call, using cumulative sequence lengths and per-sequence
    state slots; trailing slots are ``PAD_SLOT_ID`` when ``with_padding`` is
    set. The reference is run once per non-padded sub-sequence, and both the
    outputs and the state cache are compared.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = 'cuda'
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding

    if with_padding and seqlen < padded_batch_size:
        pytest.skip()

    # Pick `nsplits` distinct end-of-sequence positions and turn them into
    # per-sequence lengths that sum to `seqlen`.
    nsplits = padded_batch_size - 1
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos,
                   torch.tensor([seqlen - 1])])).tolist()

    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0).to(device)

    dim = 4
    dstate = 8
    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
    A_ref = A.clone()
    B_shape = [varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape,
                    device=device,
                    dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    C_shape = [varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape,
                    device=device,
                    dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    # Guarded so that a has_D=False parametrization would not crash here.
    D_ref = D.clone() if has_D else None
    z = torch.randn(dim, seqlen, device=device, dtype=itype)
    z_ref = z.clone()
    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
                  ) if has_delta_bias else None
    u = torch.randn(dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
    delta_ref = delta.clone()

    # State cache with more entries than sequences; each real sequence gets
    # a random distinct slot.
    prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
    prev_state = torch.randn(prev_state_shape,
                             device=u.device,
                             dtype=itype,
                             requires_grad=False)
    prev_state_ref = prev_state.clone()
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=u.device)[:batch_size]
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)

    has_initial_state = torch.randint(0,
                                      2, (cumsum.shape[0] - 1, ),
                                      dtype=torch.bool,
                                      device=u.device)
    out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
                            delta_softplus, cumsum, padded_state_indices,
                            has_initial_state)

    outs_ref = []
    splits = [
        torch.split(var, seqlens, dim=-1)
        for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
    ]
    for i in range(len(seqlens)):
        u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
        slot = padded_state_indices[i]
        if slot == PAD_SLOT_ID:
            continue
        # `final_state_out` is passed so that the reference writes the last
        # state back into its own copy of the cache.
        out_ref_s, _ = selective_scan_ref(
            u_s,
            delta_s,
            A_ref,
            B_s,
            C_s,
            D_ref,
            z=z_s,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            return_last_state=return_last_state,
            prev_state=prev_state_ref[slot].unsqueeze(0)
            if has_initial_state[i] else None,
            final_state_out=prev_state_ref[slot].unsqueeze(0))
        outs_ref.append(out_ref_s)
    out_ref = torch.cat(outs_ref, dim=-1)[0]

    # Only the leading, non-padded part of the kernel output is compared.
    unpadded_out = out[:, :out_ref.shape[-1]]
    print("Output diff max", (unpadded_out - out_ref).max())
    print("Output diff mean", (unpadded_out - out_ref).mean())
    print("Output state diff max", (prev_state - prev_state_ref).max())
    print("Output state diff mean", (prev_state - prev_state_ref).mean())
    assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
    selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
                              delta_softplus, cumsum, padded_state_indices,
                              has_initial_state, prev_state)


549
550
@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
                                                   has_z, itype):
    """Check ``selective_state_update`` with ``state_batch_indices``.

    The state cache has more entries than the batch; each real batch row is
    mapped to a random slot and, when ``with_padding`` is set, extra rows
    mapped to ``PAD_SLOT_ID`` are appended. The test checks that the slots
    in use match the reference, that the unused slots are untouched, and
    that the padded input rows are not modified by the call.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[state_indices] = False
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
    ],
                                        dim=0)
    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(padded_batch_size, dstate, device=device)
    C = torch.randn(padded_batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].clone()
    state_before = state.clone()
    # Snapshots of the inputs, taken before the kernel call, so the
    # "padded rows stay the same" checks below compare against real
    # pre-call values rather than a tensor against itself.
    x_before, dt_before, B_before, C_before = (t.clone()
                                               for t in (x, dt, B, C))
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True,
                                 state_batch_indices=padded_state_indices,
                                 pad_slot_id=PAD_SLOT_ID)
    out_ref = selective_state_update_ref(
        state_ref,
        x[:batch_size],
        dt[:batch_size],
        A,
        B[:batch_size],
        C[:batch_size],
        D=D,
        z=z[:batch_size] if z is not None else None,
        dt_bias=dt_bias,
        dt_softplus=True)

    print("Output diff max", (out[:batch_size] - out_ref).max())
    print("Output diff mean", (out[:batch_size] - out_ref).mean())
    print("Output state diff max", (state[state_indices, :] - state_ref).max())
    print("Output state diff mean",
          (state[state_indices, :] - state_ref).mean())
    # test padded entries stay the same
    if with_padding:
        assert torch.equal(state_before[unused_states_bool],
                           state[unused_states_bool])
        # Padded rows start at index `batch_size`.
        assert torch.equal(x[batch_size:], x_before[batch_size:])
        assert torch.equal(dt[batch_size:], dt_before[batch_size:])
        assert torch.equal(B[batch_size:], B_before[batch_size:])
        assert torch.equal(C[batch_size:], C_before[batch_size:])

    # test "real" entries
    assert torch.allclose(state[state_indices, :],
                          state_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
        dim, dstate, ngroups, has_z, tie_hdim, itype):
    """Check the multi-head ``selective_state_update`` with batch indices.

    ``dim`` is split into heads of size 64 and ``B``/``C`` are shared across
    ``ngroups`` groups. With ``tie_hdim`` the per-head values of ``dt``,
    ``dt_bias``, ``A`` and ``D`` are repeated over the head dimension
    instead of being drawn independently. The state slots selected by
    ``state_indices`` and the output are compared with the reference.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    headdim = 64
    nheads = dim // headdim

    total_entries = 10 * batch_size
    state = torch.randn(total_entries,
                        nheads,
                        headdim,
                        dstate,
                        dtype=itype,
                        device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device)

    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    if not tie_hdim:
        dt = torch.randn(batch_size,
                         nheads,
                         headdim,
                         device=device,
                         dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        # One value per head, repeated across the head dimension (and the
        # state dimension for A).
        dt = repeat(torch.randn(batch_size, nheads, device=device,
                                dtype=itype),
                    "b h -> b h p",
                    p=headdim)
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
                         "h -> h p",
                         p=headdim)
        A = repeat(-torch.rand(nheads, device=device) - 1.0,
                   "h -> h p n",
                   p=headdim,
                   n=dstate)
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference works on a copy of just the slots used by this batch.
    state_ref = state[state_indices, :].detach().clone()
    out = selective_state_update(state,
                                 x,
                                 dt,
                                 A,
                                 B,
                                 C,
                                 D=D,
                                 z=z,
                                 dt_bias=dt_bias,
                                 dt_softplus=True,
                                 state_batch_indices=state_indices,
                                 pad_slot_id=PAD_SLOT_ID)
    out_ref = selective_state_update_ref(state_ref,
                                         x,
                                         dt,
                                         A,
                                         B,
                                         C,
                                         D=D,
                                         z=z,
                                         dt_bias=dt_bias,
                                         dt_softplus=True)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state[state_indices, :],
                          state_ref,
                          rtol=rtol,
                          atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)