test_mamba_ssm.py 24.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

9
10
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
11
from vllm.attention.backends.utils import PAD_SLOT_ID
12
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
13
14
15
    selective_scan_fn,
    selective_state_update,
)
16
from vllm.platforms import current_platform
17
18


19
20
21
def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Pure-PyTorch reference for a single selective-state-update step.

    Inputs given in the head-less layout get a singleton head axis inserted,
    so the computation always runs on the 4-D (batch, nheads, dim, dstate)
    form. ``state`` is overwritten in place with the updated state.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
        dt_softplus: apply softplus to ``dt`` (after adding ``dt_bias``)
    Return:
        out: (batch, dim) or (batch, nheads, dim), cast to ``x.dtype``
    """
    has_heads = state.dim() > 3

    def _add_axis(tensor, headless_ndim, axis):
        # Insert a singleton axis only when the tensor uses the head-less layout.
        if tensor is not None and tensor.dim() == headless_ndim:
            return tensor.unsqueeze(axis)
        return tensor

    # unsqueeze returns a view, so the in-place copy_ below still reaches
    # the caller's ``state`` tensor.
    state = _add_axis(state, 3, 1)
    x = _add_axis(x, 2, 1)
    dt = _add_axis(dt, 2, 1)
    A = _add_axis(A, 2, 0)
    B = _add_axis(B, 2, 1)
    C = _add_axis(C, 2, 1)
    D = _add_axis(D, 1, 0)
    z = _add_axis(z, 2, 1)
    dt_bias = _add_axis(dt_bias, 1, 0)

    batch, nheads, dim, dstate = state.shape
    ngroups = B.shape[1]
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)

    dt_col = dt.unsqueeze(-1)  # (batch, nheads, dim, 1)
    dA = torch.exp(dt_col * A)  # (batch, nheads, dim, dstate)

    # Each group is shared by ``heads_per_group`` consecutive heads.
    heads_per_group = nheads // ngroups
    B = B.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)

    dB = dt_col * B.unsqueeze(2)  # (batch, nheads, dim, dstate)
    new_state = state * dA + dB * x.unsqueeze(-1)  # (batch, nheads, dim, dstate)
    state.copy_(new_state)

    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(x.dtype)
    return out if has_heads else out.squeeze(1)


91
92
93
94
95
96
97
98
99
100
101
102
103
104
def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
    prev_state=None,
    final_state_out=None,
):
    """
    Sequential pure-PyTorch reference for the selective scan.

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    prev_state: r(B D N), fp32; zeros are used when it is None
    final_state_out: optional buffer that receives the last state in place

    out: r(B D L), cast back to the dtype of ``u``
    last_state (optional): r(B D dstate) or c(B D dstate); returned as the
        second tuple element when ``return_last_state`` is True
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)

    batch = u.shape[0]
    seq_len = u.shape[2]
    dim, dstate = A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    B = B.float()
    C = C.float()

    # Running state: caller-supplied or zero-initialised.
    x = prev_state if prev_state is not None else A.new_zeros((batch, dim, dstate))

    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
        # Grouped B: every group serves dim // G consecutive channels.
        B = B.repeat_interleave(dim // B.shape[1], dim=1)
        deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = C.repeat_interleave(dim // C.shape[1], dim=1)

    ys = []
    for step in range(seq_len):
        x = deltaA[:, :, step] * x + deltaB_u[:, :, step]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        elif C.dim() == 3:
            y = torch.einsum("bdn,bn->bd", x, C[:, :, step])
        else:
            y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, step])
        ys.append(y)

    # Publish the state reached after the last time step (if there was one).
    if seq_len > 0:
        if final_state_out is None:
            final_state_out = x
        else:
            final_state_out.copy_(x)

    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * D.unsqueeze(-1)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    if return_last_state:
        return out, final_state_out
    return out
165
166


167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def selective_scan_opcheck_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    cu_seq_len=None,
    cache_indices=None,
    has_initial_state=None,
    ssm_states=None,
    pad_slot_id=PAD_SLOT_ID,
):
    """Run ``opcheck`` against ``torch.ops._C.selective_scan_fwd``.

    Inputs whose last-dimension stride is not 1 are made contiguous, and B/C
    get an extra axis (axis 1 for 3-D tensors without ``cu_seq_len``, axis 0
    for 2-D tensors with ``cu_seq_len``) before the op is checked.
    Nothing is returned.
    """

    def _unit_last_stride(tensor):
        # Only copy when the innermost dimension is not already dense.
        return tensor if tensor.stride(-1) == 1 else tensor.contiguous()

    def _add_group_axis(tensor):
        if cu_seq_len is None:
            return tensor.unsqueeze(1) if tensor.dim() == 3 else tensor
        return tensor.unsqueeze(0) if tensor.dim() == 2 else tensor

    u = _unit_last_stride(u)
    delta = _unit_last_stride(delta)
    if D is not None:
        D = D.contiguous()
    B = _add_group_axis(_unit_last_stride(B))
    C = _add_group_axis(_unit_last_stride(C))
    if z is not None:
        z = _unit_last_stride(z)

    op_args = (
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        cu_seq_len,
        cache_indices,
        has_initial_state,
        ssm_states,
        pad_slot_id,
    )

    # Disable test_autograd_registration for now as it seems to trigger
    # a bogus error.
    opcheck(
        torch.ops._C.selective_scan_fwd,
        op_args,
        test_utils=["test_schema", "test_faketensor"],
    )


@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("seqlen", [128, 1024, 4096])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 3])
def test_selective_scan(
    is_variable_B,
    is_variable_C,
    varBC_groups,
    has_D,
    has_z,
    has_delta_bias,
    delta_softplus,
    seqlen,
    itype,
    wtype,
    scan_chunks,
):
    """Compare ``selective_scan_fn`` with ``selective_scan_ref``.

    The sequence is fed to ``selective_scan_fn`` in ``scan_chunks`` pieces
    that share one ``state`` tensor; the concatenated outputs and the final
    ``state`` are then checked against a single reference pass over the whole
    sequence. Finally the op is run through ``selective_scan_opcheck_fn``.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 6e-4, 2e-3
    elif itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    else:
        rtol, atol = 3e-3, 5e-3
    # set seed
    current_platform.seed_everything(0)
    batch_size = 1
    dim = 4
    dstate = 8

    def _bc_shape(is_variable):
        # Shape shared by B and C for the chosen layout.
        if not is_variable:
            return [dim, dstate]
        if varBC_groups == 1:
            return [batch_size, dstate, seqlen]
        return [batch_size, varBC_groups, dstate, seqlen]

    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
    A_ref = A.clone()
    B = torch.randn(
        _bc_shape(is_variable_B),
        device=device,
        dtype=itype if is_variable_B else wtype,
    )
    B_ref = B.clone()
    C = torch.randn(
        _bc_shape(is_variable_C),
        device=device,
        dtype=itype if is_variable_C else wtype,
    )
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    # Guarded so the has_D=False configuration does not crash on None.clone().
    D_ref = D.clone() if D is not None else None
    z = (
        torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
        if has_z
        else None
    )
    z_ref = z.clone() if z is not None else None
    delta_bias = (
        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
        if has_delta_bias
        else None
    )
    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)
    delta_ref = delta.clone()
    state_shape = (batch_size, u.shape[1], int(A.shape[1]))
    state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False)

    out = None
    outs = []
    chunk_len = seqlen // scan_chunks
    for c in range(scan_chunks):
        chunk_start = chunk_len * c
        # The last chunk absorbs the remainder when seqlen is not divisible.
        chunk_end = seqlen if c == scan_chunks - 1 else chunk_len * (c + 1)
        chunk = slice(chunk_start, chunk_end)
        _B = B[..., chunk] if is_variable_B else B
        # Slice C according to its own flag (was keyed on is_variable_B).
        _C = C[..., chunk] if is_variable_C else C
        _z = z[..., chunk] if z is not None else None
        out = selective_scan_fn(
            u[..., chunk],
            state,
            delta[..., chunk],
            A,
            _B,
            _C,
            D,
            z=_z,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            # Chunks after the first continue from the carried-over state.
            has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool)
            if c > 0
            else None,
        )
        outs.append(out)
    if len(outs) > 1:
        out = torch.cat(outs, dim=-1)

    out_ref, state_ref = selective_scan_ref(
        u_ref,
        delta_ref,
        A_ref,
        B_ref,
        C_ref,
        D_ref,
        z=z_ref,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        return_last_state=True,
    )

    assert out is not None and out_ref is not None
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    assert state is not None and state_ref is not None
    assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        ssm_states=state,
    )
376

377

378
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    """Check ``selective_state_update`` against ``selective_state_update_ref``.

    Both the state tensor passed to the kernel and the ``out`` buffer are
    compared with the reference results for a batch of one.
    """
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            atol *= 2
    else:
        rtol, atol = 5e-3, 1e-2
    # set seed
    current_platform.seed_everything(0)
    batch_size = 1

    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference mutates its own copy of the state.
    state_ref = state.detach().clone()

    optional_args = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    selective_state_update(state, x, dt, A, B, C, out=out, **optional_args)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **optional_args)

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
412
413


414
415
@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [1, 256, 1024, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(
    with_padding,
    is_variable_B,
    is_variable_C,
    varBC_groups,
    has_D,
    has_z,
    has_delta_bias,
    delta_softplus,
    return_last_state,
    seqlen,
    itype,
    wtype,
):
    """Compare ``selective_scan_fn`` on packed variable-length input with
    per-sequence runs of ``selective_scan_ref``.

    A total of ``seqlen`` tokens is split at random positions into
    ``padded_batch_size`` sequences. The last ``padding`` entries of the state
    index tensor are ``PAD_SLOT_ID``; the reference skips those sequences.
    Outputs and the whole state cache are then compared, and the op is run
    through ``selective_scan_opcheck_fn``.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 6e-4, 2e-3
    elif itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    else:
        rtol, atol = 3e-3, 5e-3
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1 if seqlen < 10 else 4
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding

    if with_padding and seqlen < padded_batch_size:
        pytest.skip()

    # Pick random end-of-sequence positions and turn them into lengths.
    nsplits = padded_batch_size - 1
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens = torch.diff(
        torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
    ).tolist()
    assert sum(seqlens) == seqlen
    assert all(s > 0 for s in seqlens)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).to(
        device
    )

    dim = 4
    dstate = 8
    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
    A_ref = A.clone()
    B_shape = [varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    C_shape = [varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    # Guarded so the has_D=False configuration does not crash on None.clone().
    D_ref = D.clone() if D is not None else None
    # NOTE: z is always supplied here; has_z is not consulted by this test.
    z = torch.randn(dim, seqlen, device=device, dtype=itype)
    z_ref = z.clone()
    delta_bias = (
        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
        if has_delta_bias
        else None
    )
    u = torch.randn(dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)
    delta_ref = delta.clone()

    prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
    prev_state = torch.randn(
        prev_state_shape, device=u.device, dtype=itype, requires_grad=False
    )
    prev_state_ref = prev_state.clone()
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[
        :batch_size
    ]
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=-1,
    )

    has_initial_state = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device
    )
    out = selective_scan_fn(
        u,
        prev_state,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        cumsum,
        padded_state_indices,
        has_initial_state,
    )

    outs_ref = []
    splits = [
        torch.split(var, seqlens, dim=-1)
        for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
    ]
    for i in range(len(seqlens)):
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        u_s, delta_s, B_s, C_s, z_s = (v[i].unsqueeze(0) for v in splits)
        ref_result = selective_scan_ref(
            u_s,
            delta_s,
            A_ref,
            B_s,
            C_s,
            D_ref,
            z=z_s,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            return_last_state=return_last_state,
            prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
            if has_initial_state[i]
            else None,
            # The reference writes its final state into the cache slot.
            final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0),
        )
        # The reference returns a tuple only when return_last_state is set.
        out_ref_s = ref_result[0] if return_last_state else ref_result
        outs_ref.append(out_ref_s)
    out_ref = torch.cat(outs_ref, dim=-1)[0]

    # Drop the columns that belong to padded sequences.
    unpadded_out = out[:, : out_ref.shape[-1]]
    out_diff = (unpadded_out - out_ref).abs()
    state_diff = (prev_state - prev_state_ref).abs()
    print("Output diff max", out_diff.max())
    print("Output diff mean", out_diff.mean())
    print("Output state diff max", state_diff.max())
    print("Output state diff mean", state_diff.mean())
    assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)
    selective_scan_opcheck_fn(
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        cumsum,
        padded_state_indices,
        has_initial_state,
        prev_state,
    )


592
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
    with_padding, dim, dstate, has_z, itype
):
    """Check ``selective_state_update`` when states are addressed by index.

    ``batch_size`` rows of a larger state cache are selected through
    ``state_batch_indices``; with padding, extra batch rows carry
    ``PAD_SLOT_ID``. The selected states and the first ``batch_size`` output
    rows are compared with ``selective_state_update_ref``. With padding, the
    unselected cache entries and the padded input rows must be unchanged.
    """
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
        if torch.version.hip:
            atol *= 2
    else:
        rtol, atol = 5e-3, 1e-2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[state_indices] = False
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(padded_batch_size, dstate, device=device)
    C = torch.randn(padded_batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].clone()

    # Snapshots taken before the kernel runs, so the "unchanged" checks below
    # compare against real pre-call values rather than a tensor with itself.
    state_before = state.clone()
    x_before = x.clone()
    dt_before = dt.clone()
    B_before = B.clone()
    C_before = C.clone()

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref,
        x[:batch_size],
        dt[:batch_size],
        A,
        B[:batch_size],
        C[:batch_size],
        D=D,
        # z is None when has_z is False; slicing None would raise.
        z=z[:batch_size] if z is not None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
    )

    out_diff = (out[:batch_size] - out_ref).abs()
    state_diff = (state[state_indices, :] - state_ref).abs()
    print("Output diff max", out_diff.max())
    print("Output diff mean", out_diff.mean())
    print("Output state diff max", state_diff.max())
    print("Output state diff mean", state_diff.mean())
    # test padded entries stay the same
    if with_padding:
        assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
        # Padded rows start at index batch_size (not batch_size + 1).
        assert torch.equal(x_before[batch_size:], x[batch_size:])
        assert torch.equal(dt_before[batch_size:], dt[batch_size:])
        assert torch.equal(B_before[batch_size:], B[batch_size:])
        assert torch.equal(C_before[batch_size:], C[batch_size:])

    # test "real" entries
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
680
681


682
@pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 4])
@pytest.mark.parametrize("dstate", [16, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
    dim, dstate, ngroups, has_z, tie_hdim, itype
):
    """Check the multi-head ``selective_state_update`` with state indices.

    ``dim`` is split into heads of size 64. With ``tie_hdim`` the per-head
    values of dt, dt_bias, A and D are repeated across the head dimension.
    Selected cache entries and the output are compared with
    ``selective_state_update_ref``.
    """
    device = "cuda"
    if itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    elif itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    else:
        rtol, atol = 5e-3, 3e-2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    headdim = 64
    nheads = dim // headdim
    total_entries = 10 * batch_size

    state = torch.randn(
        total_entries, nheads, headdim, dstate, dtype=itype, device=device
    )
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )

    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    out = torch.empty_like(x)
    if tie_hdim:
        # One value per head, repeated over the head dimension (and dstate).
        dt = repeat(
            torch.randn(batch_size, nheads, device=device, dtype=itype),
            "b h -> b h p",
            p=headdim,
        )
        dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
        A = repeat(
            -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
        )
        D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
    else:
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)
    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None
    # The reference mutates its own copy of the selected states.
    state_ref = state[state_indices, :].detach().clone()

    optional_args = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        state_batch_indices=state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
        **optional_args,
    )
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **optional_args)

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)