# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
    _query_start_loc_to_chunk_indices_offsets)

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py


# this is the segsum implementation taken from above
def segsum(x):
    """Calculates segment sum.

    For an input whose last dimension has length T, returns a tensor of
    shape (..., T, T) where entry [i, j] equals ``sum(x[..., j + 1:i + 1])``
    for ``j <= i`` (so the diagonal is 0) and ``-inf`` for ``j > i``.
    """
    size = x.size(-1)
    # grid[..., d, e] == x[..., d] for every column e
    grid = x.unsqueeze(-1).expand(*x.shape, size)
    ones = torch.ones(size, size, device=x.device, dtype=bool)
    below_diag = torch.tril(ones, diagonal=-1)
    on_or_below_diag = torch.tril(ones, diagonal=0)
    # keep only rows strictly below the diagonal, then prefix-sum down rows
    sums = torch.cumsum(grid.masked_fill(~below_diag, 0), dim=-2)
    return sums.masked_fill(~on_or_below_diag, -torch.inf)


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """Naive pure-torch chunked SSD scan, used as a reference for the kernels.

    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads); consumed directly as the per-step log
            decay (it is exponentiated here; no dt scaling happens inside)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; must evenly divide ``length``
        initial_states: optional (batch, 1, n_heads, d_head, d_state) state
            entering the first chunk; zeros are used when None
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state), the state after the
            last chunk
    Asserts that X, A, B, C share a dtype and that ``length`` is a multiple
    of ``block_len``.
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
                  for x in (X, A, B, C))

    # A -> (batch, n_heads, n_chunks, block_len); cumulative decay per chunk
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    # i.e. the state at the end of each chunk from that chunk's inputs alone
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
    #    chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    # prepend the incoming state as an extra leading "chunk"
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # new_states[:, z] is the state entering chunk z; the extra last entry is
    # the state after the final chunk
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms
    # (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


def generate_random_inputs(batch_size,
                           seqlen,
                           n_heads,
                           d_head,
                           itype,
                           device='cuda',
                           d_state=None):
    """Create seeded random inputs for the SSD kernels.

    The RNG is reseeded (seed 0) on every call, so identical arguments
    produce identical tensors.

    Arguments:
        batch_size, seqlen, n_heads, d_head: tensor dimensions.
        itype: dtype of every returned tensor.
        device: device on which every tensor is created.
        d_state: last dimension of B and C. Defaults to ``d_head`` (the
            original behavior of this helper), so existing callers draw
            exactly the same tensors.
    Returns:
        A:  (n_heads,) with values in (-e, -1]
        dt: (batch_size, seqlen, n_heads), softplus(randn - 4) > 0
        X:  (batch_size, seqlen, n_heads, d_head)
        B:  (batch_size, seqlen, n_heads, d_state)
        C:  (batch_size, seqlen, n_heads, d_state)
    """
    if d_state is None:
        d_state = d_head

    current_platform.seed_everything(0)
    A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
    # shifting by -4 before softplus keeps the step sizes mostly small
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
        4)
    X = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_state),
                    dtype=itype,
                    device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_state),
                    dtype=itype,
                    device=device)

    return A, dt, X, B, C


def generate_continuous_batched_examples(example_lens_by_batch,
                                         num_examples,
                                         full_length,
                                         last_taken,
                                         exhausted,
                                         n_heads,
                                         d_head,
                                         itype,
                                         device='cuda',
                                         return_naive_ref=True):
    """Yield continuous (packed) batches cut from random full-length examples.

    ``num_examples`` random examples of length ``full_length`` are generated
    once. Each entry of ``example_lens_by_batch`` is a sequence giving how
    many tokens to take from each example for the next packed batch. Each
    example is consumed left to right and wraps back to its start once it
    has been fully taken.

    ``last_taken`` (example index -> next start offset) and ``exhausted``
    (example index -> True if the example was just fully consumed) are
    caller-owned dicts that are updated in place.

    If ``return_naive_ref`` is True, the naive torch implementation
    ``ssd_minimal_discrete`` computes the reference output once on the
    full-length examples, and each yield carries the matching per-example
    reference slices; otherwise ``None`` is yielded in their place.

    Yields:
        (ref_slices_or_None, cu_seqlens, seq_idx, (A, dt, X, B, C)) where
        seq_idx and the packed dt/X/B/C have a leading batch dim of 1.
    """
    # generate the full-length examples; forward `device` so the data ends
    # up on the same device as the metadata built below
    A, dt, X, B, C = generate_random_inputs(num_examples,
                                            full_length,
                                            n_heads,
                                            d_head,
                                            itype,
                                            device=device)

    if return_naive_ref:
        # block_len must divide full_length (asserted inside the reference)
        Y_min, _ = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                        A * dt,
                                        B,
                                        C,
                                        block_len=full_length // 4)

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from the first example,
    #       4 samples from the second example, etc.
    def get_continuous_batch(example_lens: tuple[int, ...]):
        indices = []
        for i, n in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + n))
            last_taken[i] = (c + n) % full_length
            exhausted[i] = last_taken[i] == 0

        return (torch.concat([t[i, s:e] for i, (s, e) in enumerate(indices)
                              ]).unsqueeze(0) for t in (dt, X, B, C))

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length": n modulo full_length, except that exact multiples
    # of full_length map to full_length instead of 0.
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length

    # IND_S / IND_E: start / end offsets, within each full-length example,
    # of the slice taken for the current batch (used to slice the reference)
    IND_E = None
    for spec in example_lens_by_batch:

        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata; (0, *spec) accepts both tuple and list specs
        cu_seqlens = torch.tensor((0, *spec), device=device).cumsum(dim=0)
        seq_idx = torch.zeros(cu_seqlens[-1],
                              dtype=torch.int32,
                              device=cu_seqlens.device)
        for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
            seq_idx[srt:end] = i

        # for cont batch
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield ([Y_min[s, IND_S[s]:IND_E[s]]
                for s in range(num_examples)] if return_naive_ref else None,
               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                         itype):

    # this tests the kernels on a single example (no batching), comparing
    # mamba_chunk_scan_combined against the naive ssd_minimal_discrete

    # TODO: the bfloat16 case requires higher thresholds. To be investigated

    if itype == torch.bfloat16:
        atol, rtol = 5e-2, 5e-2
    else:
        atol, rtol = 8e-3, 5e-3

    # inputs are deterministic: generate_random_inputs seeds the RNG itself
    batch_size = 1
    # ssd_minimal_discrete requires chunk_size divide seqlen
    # - this is only required for generating the reference seqs,
    #   it is not an operational limitation.
    seqlen, chunk_size = seq_len_chunk_size

    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
                                            d_head, itype)

    # the reference takes dt-scaled X and A; the kernel takes dt separately
    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                  B, C, chunk_size)

    Y = torch.empty_like(X)
    final_state = mamba_chunk_scan_combined(X,
                                            dt,
                                            A,
                                            B,
                                            C,
                                            chunk_size,
                                            D=None,
                                            return_final_states=True,
                                            out=Y)

    # just test the last position in the sequence
    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)

    # just test the last head
    # NOTE, in the kernel we always cast states to fp32
    torch.testing.assert_close(final_state[:, -1],
                               final_state_min[:, -1].to(torch.float32),
                               atol=atol,
                               rtol=rtol)


@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [

        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (64, 8, 2, [(4, 4), (4, 4), (4, 4),
                    (4, 4)]),  # chunk_size larger than cont batches
        (64, 8, 5, [
            (64, 32, 16, 8, 8),
            (8, 16, 32, 16, 8),
            (8, 8, 16, 32, 16),
        ]),  # more examples with varied lengths

        # large-ish chunk_size (256)
        (64, 256, 1, [(5, ), (1, ), (1, ),
                      (1, )]),  # irregular sizes with small sequences
        (64, 256, 2, [(5, 30), (1, 2), (1, 2),
                      (1, 2)]),  # irregular sizes with small sequences

        # we also need to test some large seqlen
        # to catch errors with init states decay
        (768, 128, 2, [(138, 225), (138, 225)]),
    ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                     itype):

    # this test with multiple examples in a continuous batch
    # (i.e. chunked prefill); the final states of each batch are fed back
    # as initial states of the next one

    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # This test can have larger error for longer sequences
    if seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

    states = None
    for Y_min, cu_seqlens, seq_idx, (
            A, dt, X, B, C) in generate_continuous_batched_examples(
                cases, num_examples, seqlen, last_taken, exhausted, n_heads,
                d_head, itype):

        chunk_indices, chunk_offsets = \
            _query_start_loc_to_chunk_indices_offsets(
                cu_seqlens, chunk_size, cu_seqlens[-1])

        Y = torch.empty_like(X)
        new_states = mamba_chunk_scan_combined(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            return_varlen_states=True,
            initial_states=states,
            out=Y,
        )

        # compare each example's slice of the packed output with the
        # reference slice for the same example
        for i in range(num_examples):

            # just test head 0, head-dim 0
            Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
            Y_min_eg = Y_min[i][:, 0, 0]
            torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)

        # update states; examples that wrapped around restart from zero state
        states = new_states
        for i, clear in exhausted.items():
            if clear:
                states[i].fill_(0.)
                exhausted[i] = False


@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
    (16, 2, 8, 13),
    (270, 88, 212, 203),
    (16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):

    # This test verifies the correctness of the chunked prefill implementation
    # in the mamba2 ssd kernels, by comparing concatenation (in the sequence
    # dimension) of chunked results with the full sequence result.
    # It is different from test_mamba_chunk_scan_cont_batch by:
    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
    #    reference outputs. Instead, it compares chunked kernel outputs to full
    #    sequence kernel outputs. This is the most straightforward way to
    #    assert chunked prefill correctness.
    # 2. It focuses on cases where sequences change in the middle of mamba
    #    chunks, and not necessarily on chunk boundaries.

    max_seqlen = max(seqlens)
    # This test can have larger error for longer sequences
    if max_seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    num_sequences = len(seqlens)
    n_heads = 16
    d_head = 64
    itype = torch.float32

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
        generate_continuous_batched_examples([seqlens],
                                             num_sequences,
                                             max_seqlen,
                                             last_taken,
                                             exhausted,
                                             n_heads,
                                             d_head,
                                             itype,
                                             return_naive_ref=False))
    device = X.device
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)

    def build_metadata(lens):
        # cumulative offsets (with a leading 0) and per-token sequence ids
        # (shape (1, total), int32) for a packed batch with lengths `lens`
        cu = torch.cat(
            [torch.tensor([0], device=device),
             torch.cumsum(lens, dim=0)],
            dim=0)
        idx = torch.repeat_interleave(torch.arange(len(lens), device=device),
                                      lens,
                                      output_size=cu[-1])
        return cu, idx.unsqueeze(0).to(torch.int32)

    def pack(x, starts, lens):
        # concatenate x[:, s:s + n] for every (s, n) along the sequence dim
        return torch.cat([x[:, s:s + n, ...] for s, n in zip(starts, lens)],
                         dim=1)

    def run_kernel(x, dt_, b, c, cu, idx, initial_states):
        # run the combined SSD kernel on one packed batch and return
        # (output, per-sequence final states)
        chunk_indices, chunk_offsets = \
            _query_start_loc_to_chunk_indices_offsets(cu, chunk_size, cu[-1])
        out = torch.empty_like(x)
        states = mamba_chunk_scan_combined(
            x,
            dt_,
            A,
            b,
            c,
            chunk_size,
            D=None,
            cu_seqlens=cu,
            seq_idx=idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            return_varlen_states=True,
            initial_states=initial_states,
            out=out,
        )
        return out, states

    ## full seqlen computation
    Y_ref, state_ref = run_kernel(X, dt, B, C, cu_seqlens, seq_idx, None)

    ## chunked seqlen computation
    seq_starts = cu_seqlens[:-1]

    # first chunk: the first half (rounded down) of every sequence
    chunked_seqlens = seqlens // 2
    chunked_cu_seqlens, chunked_seq_idx = build_metadata(chunked_seqlens)
    X_chunked, dt_chunked, B_chunked, C_chunked = (
        pack(t, seq_starts, chunked_seqlens) for t in (X, dt, B, C))
    Y_partial, partial_state = run_kernel(X_chunked, dt_chunked, B_chunked,
                                          C_chunked, chunked_cu_seqlens,
                                          chunked_seq_idx, None)

    # remaining chunk: the rest of every sequence
    remaining_chunked_seqlens = seqlens - chunked_seqlens
    remaining_chunked_cu_seqlens, remaining_chunked_seq_idx = build_metadata(
        remaining_chunked_seqlens)
    remaining_starts = seq_starts + chunked_seqlens
    (remaining_X_chunked, remaining_dt_chunked, remaining_B_chunked,
     remaining_C_chunked) = (pack(t, remaining_starts,
                                  remaining_chunked_seqlens)
                             for t in (X, dt, B, C))

    def interleave(pt1, pt2):
        # rebuild the full packed batch from a first-part batch (pt1) and a
        # remaining-part batch (pt2) by alternating per-sequence slices
        first = chunked_cu_seqlens
        rest = remaining_chunked_cu_seqlens
        parts = []
        for i in range(num_sequences):
            parts.append(pt1[:, first[i]:first[i + 1], ...])
            parts.append(pt2[:, rest[i]:rest[i + 1], ...])
        return torch.cat(parts, dim=1)

    # assert input chunking is correct
    assert interleave(X_chunked, remaining_X_chunked).equal(X)
    assert interleave(dt_chunked, remaining_dt_chunked).equal(dt)
    assert interleave(B_chunked, remaining_B_chunked).equal(B)
    assert interleave(C_chunked, remaining_C_chunked).equal(C)

    # resume every sequence from the states produced by the first chunk
    Y_chunked, state_chunked = run_kernel(remaining_X_chunked,
                                          remaining_dt_chunked,
                                          remaining_B_chunked,
                                          remaining_C_chunked,
                                          remaining_chunked_cu_seqlens,
                                          remaining_chunked_seq_idx,
                                          partial_state)
    Y = interleave(Y_partial, Y_chunked)

    # kernel chunked is same as kernel overall
    for i in range(num_sequences):
        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        # `i=i` binds the current loop value into each message callback
        torch.testing.assert_close(
            Y_seq[:, :chunked_seqlens[i], ...],
            Y_ref_seq[:, :chunked_seqlens[i], ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x, i=i: f"seq{i} output part1 " + x)
        torch.testing.assert_close(
            Y_seq[:, chunked_seqlens[i]:, ...],
            Y_ref_seq[:, chunked_seqlens[i]:, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x, i=i: f"seq{i} output part2 " + x)

        state_seq = state_chunked[i]
        state_seq_ref = state_ref[i]
        torch.testing.assert_close(
            state_seq,
            state_seq_ref,
            atol=atol,
            rtol=rtol,
            msg=lambda x, i=i: f"seq{i} state " + x)