"fern/pages/components/frontend/frontend-guide.md" did not exist on "8c8680b179c1e7fc13fb8c79100386561dc5dd99"
test_mamba_ssm.py 36.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

9
10
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
11
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
12
13
14
    selective_scan_fn,
    selective_state_update,
)
15
from vllm.utils.torch_utils import set_random_seed
16
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
17
18


19
20
21
def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """Pure-PyTorch reference for one step of the selective state update.

    ``state`` is updated in place (via ``copy_``), so the caller's tensor
    holds the new state after the call.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
        dt_softplus: apply softplus to ``dt`` (after adding ``dt_bias``)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    # Normalise every input to the "with heads" layout. ``unsqueeze`` returns
    # a view, so the in-place state update below still reaches the caller.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt

    dt_col = dt.unsqueeze(-1)  # (batch, nheads, dim, 1)
    dA = torch.exp(dt_col * A)  # (batch, nheads, dim, dstate)
    # Each group's B/C is shared by ``nheads // ngroups`` consecutive heads.
    heads_per_group = nheads // ngroups
    B = B.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    dB = dt_col * B.unsqueeze(2)  # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))  # (batch, nheads, dim, dstate)

    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out


91
92
93
94
95
96
97
98
99
100
101
102
103
104
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
    prev_state=None,
    final_state_out=None,
):
    """Sequential pure-PyTorch reference for the selective scan.

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    prev_state: r(B D N), fp32; used as the initial state when given,
        otherwise the scan starts from zeros
    final_state_out: optional buffer; when given, the state after the last
        step is copied into it in place

    out: r(B D L), cast back to the dtype of ``u``
    last_state (optional): r(B D dstate) or c(B D dstate)

    Returns ``out``, or ``(out, last_state)`` when ``return_last_state`` is
    true.
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    B = B.float()
    C = C.float()
    x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state

    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
        # Grouped B: every group serves ``dim // ngroups`` consecutive channels.
        B = B.repeat_interleave(dim // B.shape[1], dim=1)
        deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = C.repeat_interleave(dim // C.shape[1], dim=1)

    seqlen = u.shape[2]
    ys = []
    for i in range(seqlen):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        elif C.dim() == 3:
            y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
        else:
            y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)

    # ``x`` now holds the state after the last step: hand it back either as
    # a fresh tensor or by filling the caller-provided buffer in place.
    if final_state_out is None:
        final_state_out = x
    else:
        final_state_out.copy_(x)

    out = y if D is None else y + u * D.unsqueeze(-1)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, final_state_out)
165
166


167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def selective_scan_opcheck_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    cu_seq_len=None,
    cache_indices=None,
    has_initial_state=None,
    ssm_states=None,
    pad_slot_id=PAD_SLOT_ID,
    block_size=2048,
    block_idx_first_scheduled_token=None,
    block_idx_last_scheduled_token=None,
    initial_state_idx=None,
    cu_chunk_seqlen=None,
    last_chunk_indices=None,
):
    """Run ``opcheck`` on ``torch.ops._C.selective_scan_fwd``.

    The tensor arguments are first made contiguous where the checks below
    require it, and ``B``/``C`` get an extra leading or group dimension
    depending on whether ``cu_seq_len`` is given. All arguments are then
    forwarded positionally to the op. Returns nothing.
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()

    # Batched input (no cu_seq_len): 3-D B/C gain a dimension at index 1.
    # Varlen input (cu_seq_len given): 2-D B/C gain a leading dimension.
    if B.dim() == 3 and cu_seq_len is None:
        B = B.unsqueeze(1)
    if B.dim() == 2 and cu_seq_len is not None:
        B = B.unsqueeze(0)
    if C.dim() == 3 and cu_seq_len is None:
        C = C.unsqueeze(1)
    if C.dim() == 2 and cu_seq_len is not None:
        C = C.unsqueeze(0)

    # Disable test_autograd_registration for now as it seems to trigger
    # a bogus error.
    opcheck(
        torch.ops._C.selective_scan_fwd,
        (
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            delta_softplus,
            cu_seq_len,
            cache_indices,
            has_initial_state,
            ssm_states,
            pad_slot_id,
            block_size,
            block_idx_first_scheduled_token,
            block_idx_last_scheduled_token,
            initial_state_idx,
            cu_chunk_seqlen,
            last_chunk_indices,
        ),
        test_utils=["test_schema", "test_faketensor"],
    )


@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("seqlen", [128, 1024, 4096])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 3])
def test_selective_scan(
    is_variable_B,
    is_variable_C,
    varBC_groups,
    has_D,
    has_z,
    has_delta_bias,
    delta_softplus,
    seqlen,
    itype,
    wtype,
    scan_chunks,
):
    """Check ``selective_scan_fn`` against ``selective_scan_ref``.

    The sequence is fed to the kernel in ``scan_chunks`` consecutive chunks
    that share one ``state`` tensor; the concatenated output and the final
    state are compared with a single reference pass over the full sequence.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = "cuda"
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # set seed
    set_random_seed(0)
    batch_size = 1
    dim = 4
    dstate = 8
    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
    A_ref = A.clone()
    if not is_variable_B:
        B_shape = [dim, dstate]
    elif varBC_groups == 1:
        B_shape = [batch_size, dstate, seqlen]
    else:
        B_shape = [batch_size, varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    if not is_variable_C:
        C_shape = [dim, dstate]
    elif varBC_groups == 1:
        C_shape = [batch_size, dstate, seqlen]
    else:
        C_shape = [batch_size, varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    D_ref = D.clone() if D is not None else None
    z = (
        torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
        if has_z
        else None
    )
    z_ref = z.clone() if z is not None else None
    delta_bias = (
        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
        if has_delta_bias
        else None
    )
    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)
    delta_ref = delta.clone()
    state_shape = (batch_size, u.shape[1], int(A.shape[1]))
    state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False)

    out = None
    outs = []
    chunked_prompt_len = seqlen // scan_chunks
    for c in range(scan_chunks):
        chunk_start = chunked_prompt_len * c
        # The last chunk absorbs the remainder when seqlen is not divisible.
        chunk_end = seqlen if c == scan_chunks - 1 else chunked_prompt_len * (c + 1)
        # Only sequence-dependent B/C/z carry a trailing length axis to slice.
        _B = B[..., chunk_start:chunk_end] if is_variable_B else B
        _C = C[..., chunk_start:chunk_end] if is_variable_C else C
        _z = z[..., chunk_start:chunk_end] if z is not None else None
        out = selective_scan_fn(
            u[..., chunk_start:chunk_end],
            state,
            delta[..., chunk_start:chunk_end],
            A,
            _B,
            _C,
            D,
            z=_z,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            # Chunks after the first continue from the state left in ``state``.
            has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
            if c > 0
            else None,
            pad_slot_id=PAD_SLOT_ID,
            block_size=2048,
            block_idx_first_scheduled_token=None,
            block_idx_last_scheduled_token=None,
            initial_state_idx=None,
        )
        outs.append(out)
    if len(outs) > 1:
        out = torch.cat(outs, dim=-1)

    out_ref, state_ref = selective_scan_ref(
        u_ref,
        delta_ref,
        A_ref,
        B_ref,
        C_ref,
        D_ref,
        z=z_ref,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        return_last_state=True,
    )

    assert out is not None and out_ref is not None
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    assert state is not None and state_ref is not None
    assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        ssm_states=state,
        block_size=2048,
    )
394

395

396
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    """Check ``selective_state_update`` against ``selective_state_update_ref``.

    Both the updated state and the output are compared for a single-entry
    batch without heads.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            atol *= 2
    # set seed
    set_random_seed(0)
    batch_size = 1
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference mutates its state argument, so give it its own copy.
    state_ref = state.detach().clone()
    selective_state_update(
        state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
    )
    out_ref = selective_state_update_ref(
        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
430
431


432
433
434
435
436
437
438
439
440
441
442
443
444
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("max_seq_len", [1, 2, 4])
def test_selective_state_update_varlen(dim, dstate, has_z, itype, max_seq_len):
    """Check ``selective_state_update`` with ``cu_seqlens`` against the reference.

    Each sequence contributes between 1 and ``max_seq_len`` tokens, packed
    along the first axis. The reference is applied token by token, reusing
    (and updating in place) the state row of the owning sequence.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2
    # set seed
    set_random_seed(0)
    batch_size = 4
    token_counts = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    total_tokens = int(token_counts.sum().item())
    cu_seqlens = torch.tensor(
        [0] + torch.cumsum(token_counts, dim=0).tolist(),
        dtype=torch.int32,
        device=device,
    )
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
    )

    # Pull the boundaries to the host once instead of one ``.item()`` per use.
    boundaries = cu_seqlens.tolist()
    out_ref_list = []
    for seq_idx in range(batch_size):
        for idx in range(boundaries[seq_idx], boundaries[seq_idx + 1]):
            out_ref_list.append(
                selective_state_update_ref(
                    # Slice view: the reference updates this row in place.
                    state_ref[seq_idx : seq_idx + 1],
                    x[idx : idx + 1],
                    dt[idx : idx + 1],
                    A,
                    B[idx : idx + 1],
                    C[idx : idx + 1],
                    D=D,
                    z=z[idx : idx + 1] if z is not None else None,
                    dt_bias=dt_bias,
                    dt_softplus=True,
                )
            )
    out_ref = torch.cat(out_ref_list, dim=0)
    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


506
507
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(
    with_padding,
    is_variable_B,
    is_variable_C,
    varBC_groups,
    has_D,
    has_z,
    has_delta_bias,
    delta_softplus,
    return_last_state,
    seqlen,
    itype,
    wtype,
):
    """Check ``selective_scan_fn`` on packed variable-length sequences.

    ``seqlen`` tokens are split at random positions into sequences. With
    ``with_padding`` the trailing sequences are mapped to ``PAD_SLOT_ID``
    and are skipped by the reference. Outputs and the state cache are
    compared against ``selective_scan_ref`` run per sequence.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = "cuda"
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding

    if with_padding and seqlen < padded_batch_size:
        pytest.skip()

    # Choose random end-of-sequence positions; consecutive differences give
    # the individual sequence lengths.
    nsplits = padded_batch_size - 1
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
    ).tolist()
    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).to(
        device
    )

    dim = 4
    dstate = 8
    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
    A_ref = A.clone()
    B_shape = [varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    C_shape = [varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    D_ref = D.clone() if D is not None else None
    z = torch.randn(dim, seqlen, device=device, dtype=itype)
    z_ref = z.clone()
    delta_bias = (
        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
        if has_delta_bias
        else None
    )
    u = torch.randn(dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)
    delta_ref = delta.clone()

    prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
    prev_state = torch.randn(
        prev_state_shape, device=u.device, dtype=itype, requires_grad=False
    )
    prev_state_ref = prev_state.clone()
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )

    has_initial_state = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device
    )
    out = selective_scan_fn(
        u,
        prev_state,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        cumsum,
        padded_state_indices,
        has_initial_state,
    )

    # Reference: run each non-padded sequence separately on its own slice.
    per_seq_inputs = zip(
        *(
            torch.split(var, seqlens, dim=-1)
            for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
        )
    )
    slots = padded_state_indices.tolist()
    init_flags = has_initial_state.tolist()
    outs_ref = []
    for slot, has_init, seq_inputs in zip(slots, init_flags, per_seq_inputs):
        if slot == PAD_SLOT_ID:
            continue
        u_s, delta_s, B_s, C_s, z_s = (v.unsqueeze(0) for v in seq_inputs)
        # Integer indexing yields a view, so the reference writes the final
        # state straight back into ``prev_state_ref``.
        state_slot = prev_state_ref[slot].unsqueeze(0)
        ref_result = selective_scan_ref(
            u_s,
            delta_s,
            A_ref,
            B_s,
            C_s,
            D_ref,
            z=z_s,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            return_last_state=return_last_state,
            prev_state=state_slot if has_init else None,
            final_state_out=state_slot,
        )
        # The reference returns a tuple only when the last state is requested.
        outs_ref.append(ref_result[0] if return_last_state else ref_result)
    out_ref = torch.cat(outs_ref, dim=-1)[0]

    unpadded_out = out[:, : out_ref.shape[-1]]
    print("Output diff max", (unpadded_out - out_ref).max())
    print("Output diff mean", (unpadded_out - out_ref).mean())
    print("Output state diff max", (prev_state - prev_state_ref).max())
    print("Output state diff mean", (prev_state - prev_state_ref).mean())
    assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        cumsum,
        padded_state_indices,
        has_initial_state,
        prev_state,
        block_size=2048,
    )


685
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
    with_padding, dim, dstate, has_z, itype
):
    """Check ``selective_state_update`` with ``state_batch_indices``.

    The first ``batch_size`` rows of the inputs address randomly chosen
    slots of a larger state cache; with ``with_padding`` the remaining rows
    use ``PAD_SLOT_ID``. Real rows are compared with the reference, and
    cache slots that were not addressed as well as the padded input rows
    must be left unchanged.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[state_indices] = False
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(padded_batch_size, dstate, device=device)
    C = torch.randn(padded_batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].clone()
    # Snapshots taken before the kernel runs, so the "unchanged" checks
    # below compare against real pre-call values.
    state_before = state.clone()
    x_before, dt_before = x.clone(), dt.clone()
    B_before, C_before = B.clone(), C.clone()

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref,
        x[:batch_size],
        dt[:batch_size],
        A,
        B[:batch_size],
        C[:batch_size],
        D=D,
        z=z[:batch_size] if z is not None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
    )

    print("Output diff max", (out[:batch_size] - out_ref).max())
    print("Output diff mean", (out[:batch_size] - out_ref).mean())
    print("Output state diff max", (state[state_indices, :] - state_ref).max())
    print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
    # test padded entries stay the same
    if with_padding:
        assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
        assert torch.equal(x_before[batch_size:], x[batch_size:])
        assert torch.equal(dt_before[batch_size:], dt[batch_size:])
        assert torch.equal(B_before[batch_size:], B[batch_size:])
        assert torch.equal(C_before[batch_size:], C[batch_size:])

    # test "real" entries
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
773
774


775
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 4])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
    dim, dstate, ngroups, has_z, tie_hdim, itype
):
    """Check the multi-head ``selective_state_update`` with state indices.

    ``dim`` is split into heads of size 64. With ``tie_hdim`` the per-head
    parameters (``dt``, ``dt_bias``, ``A``, ``D``) are drawn once per head
    and repeated across the head dimension; otherwise they are drawn per
    element.
    """
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    headdim = 64
    nheads = dim // headdim

    total_entries = 10 * batch_size
    state = torch.randn(
        total_entries, nheads, headdim, dstate, dtype=itype, device=device
    )
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )

    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    out = torch.empty_like(x)
    if not tie_hdim:
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    else:
        dt = repeat(
            torch.randn(batch_size, nheads, device=device, dtype=itype),
            "b h -> b h p",
            p=headdim,
        )
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
        A = repeat(
            -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
        )
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    # Gather the addressed slots into a private copy for the reference,
    # which updates its state argument in place.
    state_ref = state[state_indices, :].detach().clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863


@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_with_num_accepted_tokens(
    dim, dstate, has_z, itype, max_seq_len
):
    """Compare varlen ``selective_state_update`` with a per-token reference.

    Each of the four sequences gets a random number of tokens in
    ``[1, max_seq_len]`` and a ``num_accepted_tokens`` value. The source state
    slot of a sequence is stored in ``state_batch_indices`` at column
    ``max(num_accepted_tokens - 1, 0)``; every token gets its own destination
    slot in ``dst_state_batch_indices``. The reference replays the tokens of a
    sequence one by one with ``selective_state_update_ref`` and snapshots the
    state after each token. The test then checks the kernel output and the
    state found in every destination slot against those reference values.
    """
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2
    else:
        # Fallback kept for dtypes that are not part of the parametrization.
        rtol, atol = 5e-3, 1e-2

    set_random_seed(0)
    batch_size = 4

    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    # Copy the lengths to the host once rather than calling .item() per loop step.
    seq_lens = tokens_per_seq.tolist()
    total_tokens = sum(seq_lens)

    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
    # Column of state_batch_indices that carries each sequence's source slot.
    src_token_pos = [max(n - 1, 0) for n in num_accepted_tokens.tolist()]

    # Host copy of the cumulative offsets, reused below to slice the token axis.
    seq_offsets = [0] + torch.cumsum(tokens_per_seq, dim=0).tolist()
    cu_seqlens = torch.tensor(seq_offsets, dtype=torch.int32, device=device)

    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)

    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    src_slots = initial_state_slots.tolist()
    for seq_idx, (pos, slot) in enumerate(zip(src_token_pos, src_slots)):
        state_batch_indices[seq_idx, pos] = slot

    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )
    # Destination slots start at 15 so they never overlap the source slots,
    # which are drawn from [0, 15).
    next_dst_slot = 15
    dst_slots_map = {}
    for seq_idx, num_tokens in enumerate(seq_lens):
        for token_idx in range(num_tokens):
            dst_state_batch_indices[seq_idx, token_idx] = next_dst_slot
            dst_slots_map[(seq_idx, token_idx)] = next_dst_slot
            next_dst_slot += 1

    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None

    state_ref_intermediate = {}
    out_ref_list = []

    for seq_idx, num_tokens in enumerate(seq_lens):
        src_slot = src_slots[seq_idx]
        # Clone so the reference works on its own copy of the source state.
        state_seq = state[src_slot : src_slot + 1].clone()
        seq_start = seq_offsets[seq_idx]

        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx
            tok = slice(global_idx, global_idx + 1)

            out_token = selective_state_update_ref(
                state_seq,
                x[tok],
                dt[tok],
                A,
                B[tok],
                C[tok],
                D=D,
                z=z[tok] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            out_ref_list.append(out_token)
            # Snapshot after this token; this relies on the reference updating
            # state_seq in place.
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()

    out_ref = torch.cat(out_ref_list, dim=0)

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )

    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol), (
        f"output mismatch: max abs diff {(out - out_ref).abs().max().item()}"
    )

    for (seq_idx, token_idx), dst_slot in dst_slots_map.items():
        state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)
        assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol), (
            f"state mismatch for seq {seq_idx}, token {token_idx} "
            f"(slot {dst_slot}): max abs diff "
            f"{(state[dst_slot] - state_ref).abs().max().item()}"
        )


@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
@pytest.mark.parametrize("max_seq_len", [2, 4])
def test_selective_state_update_varlen_with_num_accepted(
    dim, dstate, has_z, itype, max_seq_len
):
    """Check the per-token states written by varlen ``selective_state_update``.

    Four sequences with random lengths in ``[1, max_seq_len]`` are packed along
    the token axis and described by ``cu_seqlens``. Each sequence reads its
    source state from the slot stored in ``state_batch_indices`` at column
    ``max(num_accepted_tokens - 1, 0)`` and every token has its own destination
    slot in ``dst_state_batch_indices``. The reference replays the tokens one
    at a time with ``selective_state_update_ref`` and records the state after
    each token. Only the states in the destination slots are compared; the
    kernel output ``out`` is produced but not checked here.
    """
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 5e-2, 1.5e-1
        if torch.version.hip:
            atol *= 2
    else:
        # Fallback kept for dtypes that are not part of the parametrization.
        rtol, atol = 5e-3, 1e-2

    set_random_seed(0)
    batch_size = 4

    tokens_per_seq = torch.randint(1, max_seq_len + 1, (batch_size,), device=device)
    # Copy the lengths to the host once rather than calling .item() per loop step.
    seq_lens = tokens_per_seq.tolist()
    total_tokens = sum(seq_lens)

    num_accepted_tokens = torch.randint(0, max_seq_len, (batch_size,), device=device)
    num_accepted_tokens[0] = 0  # Add edge-case of no accepted tokens
    num_accepted_tokens[1] = max_seq_len  # Add edge-case of all tokens accepted
    # Column of state_batch_indices that carries each sequence's source slot.
    src_token_pos = [max(n - 1, 0) for n in num_accepted_tokens.tolist()]

    # Host copy of the cumulative offsets, reused below to slice the token axis.
    seq_offsets = [0] + torch.cumsum(tokens_per_seq, dim=0).tolist()
    cu_seqlens = torch.tensor(seq_offsets, dtype=torch.int32, device=device)

    total_state_slots = 50
    state = torch.randn(total_state_slots, dim, dstate, dtype=itype, device=device)

    state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )

    initial_state_slots = torch.randint(
        0, 15, (batch_size,), device=device, dtype=torch.int32
    )
    src_slots = initial_state_slots.tolist()
    for seq_idx, (pos, slot) in enumerate(zip(src_token_pos, src_slots)):
        state_batch_indices[seq_idx, pos] = slot

    dst_state_batch_indices = torch.full(
        (batch_size, max_seq_len), PAD_SLOT_ID, dtype=torch.int32, device=device
    )

    # Destination slots start at 15 so they never overlap the source slots,
    # which are drawn from [0, 15).
    next_dst_slot = 15
    dst_slots_map = {}
    for seq_idx, num_tokens in enumerate(seq_lens):
        for token_idx in range(num_tokens):
            dst_state_batch_indices[seq_idx, token_idx] = next_dst_slot
            dst_slots_map[(seq_idx, token_idx)] = next_dst_slot
            next_dst_slot += 1

    x = torch.randn(total_tokens, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(total_tokens, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(total_tokens, dstate, device=device)
    C = torch.randn(total_tokens, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None

    state_ref_intermediate = {}

    for seq_idx, num_tokens in enumerate(seq_lens):
        src_slot = src_slots[seq_idx]
        # Clone so the reference works on its own copy of the source state.
        state_seq = state[src_slot : src_slot + 1].clone()
        seq_start = seq_offsets[seq_idx]

        for token_idx in range(num_tokens):
            global_idx = seq_start + token_idx
            tok = slice(global_idx, global_idx + 1)

            # The returned output is intentionally ignored; only the state
            # carried in state_seq is of interest in this test.
            selective_state_update_ref(
                state_seq,
                x[tok],
                dt[tok],
                A,
                B[tok],
                C[tok],
                D=D,
                z=z[tok] if has_z else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )

            # Snapshot after this token; this relies on the reference updating
            # state_seq in place.
            state_ref_intermediate[(seq_idx, token_idx)] = state_seq.clone()

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        out=out,
        cu_seqlens=cu_seqlens,
        state_batch_indices=state_batch_indices,
        dst_state_batch_indices=dst_state_batch_indices,
        num_accepted_tokens=num_accepted_tokens,
        pad_slot_id=PAD_SLOT_ID,
    )

    for (seq_idx, token_idx), dst_slot in dst_slots_map.items():
        state_ref = state_ref_intermediate[(seq_idx, token_idx)].squeeze(0)

        assert torch.allclose(state[dst_slot], state_ref, rtol=rtol, atol=atol), (
            f"state mismatch for seq {seq_idx}, token {token_idx} "
            f"(slot {dst_slot}): max abs diff "
            f"{(state[dst_slot] - state_ref).abs().max().item()}"
        )