// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "selective_scan.h"

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#ifdef USE_ROCM
    #include <c10/hip/HIPException.h>  // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
    #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif

#ifndef USE_ROCM
    #include <cub/block/block_load.cuh>
    #include <cub/block/block_store.cuh>
    #include <cub/block/block_scan.cuh>
#else
    #include <hipcub/hipcub.hpp>
    namespace cub = hipcub;
#endif

#include "static_switch.h"

template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_, typename state_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    using state_t = state_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    static constexpr bool kVarlen = kVarlen_;

    static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1;
    static constexpr int kNLoadsIndex = kNItems / 4;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                 sizeof(typename BlockLoadVecT::TempStorage),
                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                 sizeof(typename BlockStoreT::TempStorage),
                                                 sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};

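// Forward selective-scan kernel. Each thread block handles one (sequence, dim row) pair
// and walks the sequence chunk by chunk, computing per state index
//     h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t
//     y_t = sum over states of (C_t * h_t) + D * u_t     (optionally gated by z)
// A chunk must fit in kNThreads * kNItems elements. The per-(row, state) running prefix
// is kept in shared memory across chunks and, depending on the mode, loaded from and
// written back to ssm_states (per cache slot when prefix caching is enabled).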
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr bool kVarlen = Ktraits::kVarlen;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
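    // Dynamic shared memory layout:
    //   [0, kSmemIOSize)          union of the block load/store temp storage
    //   [kSmemIOSize, kSmemSize)  block scan temp storage
    //   [kSmemSize, ...)          running prefix: kNRows * MAX_DSTATE scan_t entries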
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    int seqlen = params.seqlen;
    int sequence_start_index = batch_id;
    if constexpr (kVarlen){
        int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
        sequence_start_index = query_start_loc[batch_id];
        seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
    }
    const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

    const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
        : reinterpret_cast<int *>(params.cache_indices_ptr);
    const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; 
    // cache_index == params.pad_slot_id is defined as padding, so we exit early
    if (cache_index == params.pad_slot_id){
        return;
    }
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;

    typename Ktraits::state_t *ssm_states;
    if (params.cache_enabled) {
        // APC mode: ssm_states points to the base, we'll use absolute cache slots later
        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
            dim_id * kNRows * params.ssm_states_dim_stride;
    } else {
        // Non-APC mode: offset by cache_index as before
        ssm_states = reinterpret_cast<typename Ktraits::state_t *>(params.ssm_states_ptr) +
            cache_index * params.ssm_states_batch_stride +
            dim_id * kNRows * params.ssm_states_dim_stride;
    }
    
    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
    const int block_size = params.cache_enabled ? params.block_size : 2048;

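    // In APC mode cache_indices is 2D (batch_size, blocks per sequence): row batch_id
    // maps this sequence's logical block indices to physical slots in ssm_states.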
    const int* batch_cache_indices = cache_indices != nullptr ?
                                     cache_indices + batch_id * params.cache_indices_stride : nullptr;
    const int* block_idx_first_scheduled = params.block_idx_first_scheduled_token_ptr != nullptr ?
                                           reinterpret_cast<const int*>(params.block_idx_first_scheduled_token_ptr) : nullptr;
    const int* block_idx_last_scheduled = params.block_idx_last_scheduled_token_ptr != nullptr ?
                                          reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
    const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
                                   reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
    const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
                                 reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
    const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
                                    reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;

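    // With APC, the initial state (if any) is read from the cache slot of the block
    // named by initial_state_idx[batch_id]; otherwise it lives at cache_index.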
    const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;

    const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
                                 block_idx_first_scheduled[batch_id] : 0;

    // Determine chunk boundaries from pre-computed metadata (APC mode)
    // or fall back to simple block_size chunking.
    int first_chunk_idx, n_chunks;
    int current_position;

    if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
        const int last_chunk_idx = last_chunk_indices[batch_id];
        first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
        n_chunks = last_chunk_idx - first_chunk_idx + 1;
        // Derive current_position: if the first chunk is partial (fills remainder
        // of a started block), offset into the block accordingly.
        const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
        const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
                                        ? (block_size - first_chunk_tokens) : 0;
        current_position = block_idx_first * block_size + chunk_start_offset;
    } else {
        first_chunk_idx = 0;
        n_chunks = (seqlen + block_size - 1) / block_size;
        current_position = 0;
    }

    int tokens_processed = 0;

    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        const int chunk_tokens = (cu_chunk_seqlen != nullptr)
            ? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
            : min(block_size, seqlen - tokens_processed);
        if (chunk_tokens <= 0) break;
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];

        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
        }
        u += chunk_tokens;
        delta += chunk_tokens;
    
        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                A_val[r] *= kLog2e;
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, chunk_tokens);
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, chunk_tokens);
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }

            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
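                // Each element packs an affine state update h -> a * h + b as (a, b), with
                // a = exp(delta * A) and b = delta * u (scaled by B when B varies over the
                // sequence). The inclusive scan composes these updates, leaving the state
                // h_t for this (row, state_idx) in thread_data[i].y.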
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                 !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
                    if (threadIdx.x * kNItems + i >= chunk_tokens) {
                        thread_data[i] = make_float2(1.f, 0.f);
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                if (chunk > 0) {
                    running_prefix = smem_running_prefix[state_idx + r * MAX_DSTATE];
                } else {
                    // Load initial state
                    if (params.cache_enabled && has_initial_state && batch_cache_indices != nullptr) {
                        size_t state_offset = load_cache_slot * params.ssm_states_batch_stride +
                                             r * params.ssm_states_dim_stride +
                                             state_idx * params.ssm_states_dstate_stride;
                        running_prefix = make_float2(1.0, float(ssm_states[state_offset]));
                    } else if (has_initial_state) {
                        // Non-APC mode: load from current batch position
                        running_prefix = make_float2(1.0, float(ssm_states[state_idx * params.ssm_states_dstate_stride]));
                    } else {
                        // No initial state
                        running_prefix = make_float2(1.0, 0.0);
                    }
                }

                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;

                    // Store state at the end of each aligned chunk when cache is enabled
                    if (params.cache_enabled && batch_cache_indices != nullptr) {
                        size_t cache_slot;
                        if (chunk == n_chunks - 1) {
                            cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
                        } else {
                            const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
                            cache_slot = batch_cache_indices[block_idx_completed];
                        }

                        size_t state_offset = cache_slot * params.ssm_states_batch_stride +
                                             r * params.ssm_states_dim_stride +
                                             state_idx * params.ssm_states_dstate_stride;

                        ssm_states[state_offset] = typename Ktraits::state_t(prefix_op.running_prefix.y);
                    } else if (!params.cache_enabled && chunk == n_chunks - 1) {
                        // Non-APC mode: store only final state at current batch position
                        ssm_states[state_idx * params.ssm_states_dstate_stride] = typename Ktraits::state_t(prefix_op.running_prefix.y);
                    }
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    out_vals[r][i] += thread_data[i].y * C_val;
                }
            }
        }
        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + tokens_processed;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
        }

        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + tokens_processed;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + tokens_processed;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
            }
        }

        Bvar += chunk_tokens;
        Cvar += chunk_tokens;

        tokens_processed += chunk_tokens;
        current_position += chunk_tokens;
    }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now; this is no different from before, when each
    // block processed a single row.
    constexpr int kNRows = 1;
    // kIsVariableB and kIsVariableC are set to true to reduce binary size
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
            BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
                using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ,  kVarlen, input_t, weight_t, state_t>;
                constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                dim3 grid(params.batch, params.dim / kNRows);
                auto kernel = &selective_scan_fwd_kernel<Ktraits>;
                if (kSmemSize >= 48 * 1024) {
#ifdef USE_ROCM
                    C10_HIP_CHECK(hipFuncSetAttribute(
                        reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#endif
                }
                kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
}

template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
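    // Each configuration gives kNThreads * kNItems slots per thread block, which must
    // cover one scan chunk: the whole sequence for short inputs, or one cache block
    // when APC chunking is active (e.g. 64 * 16 = 1024 for block_size == 1024).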

    #ifndef USE_ROCM
        if (params.cache_enabled && params.block_size == 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 128) {
            selective_scan_fwd_launch<32, 4, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_fwd_launch<32, 8, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<32, 16, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
        }
    #else
        if (params.cache_enabled && params.block_size == 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 256) {
            selective_scan_fwd_launch<64, 4, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 512) {
            selective_scan_fwd_launch<64, 8, input_t, weight_t, state_t>(params, stream);
        } else if (params.seqlen <= 1024) {
            selective_scan_fwd_launch<64, 16, input_t, weight_t, state_t>(params, stream);
        } else {
            selective_scan_fwd_launch<128, 16, input_t, weight_t, state_t>(params, stream);
        }
    #endif
}

template void selective_scan_fwd_cuda<at::BFloat16, float, at::BFloat16>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::BFloat16, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, at::Half>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float, float>(SSMParamsBase &params, cudaStream_t stream);

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, STYPE, NAME, ...)       \
    if (ITYPE == at::ScalarType::Half) {                                            \
        using input_t = at::Half;                                                   \
        using weight_t = float;                                                     \
        if (STYPE == at::ScalarType::Half) {                                        \
            using state_t = at::Half;                                               \
            __VA_ARGS__();                                                          \
        } else if (STYPE == at::ScalarType::Float) {                                \
            using state_t = float;                                                  \
            __VA_ARGS__();                                                          \
        } else {                                                                    \
            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
        }                                                                           \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        using weight_t = float;                                                     \
        if (STYPE == at::ScalarType::BFloat16) {                                    \
            using state_t = at::BFloat16;                                           \
            __VA_ARGS__();                                                          \
        } else if (STYPE == at::ScalarType::Float) {                                \
            using state_t = float;                                                  \
            __VA_ARGS__();                                                          \
        } else {                                                                    \
            AT_ERROR(#NAME, " not implemented for state type '", toString(STYPE), "'"); \
        }                                                                           \
    } else if (ITYPE == at::ScalarType::Float)  {                                   \
        using input_t = float;                                                      \
        using weight_t = float;                                                     \
        using state_t = float;                                                      \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }


template<typename input_t, typename weight_t, typename state_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u,
                        const torch::Tensor delta,
                        const torch::Tensor A,
                        const torch::Tensor B,
                        const torch::Tensor C,
                        const torch::Tensor out,
                        const torch::Tensor z,
                        const torch::Tensor out_z,
                        const std::optional<at::Tensor>& D,
                        const std::optional<at::Tensor>& delta_bias,
                        const torch::Tensor ssm_states,
                        bool has_z,
                        bool delta_softplus,
                        const std::optional<at::Tensor>& query_start_loc,
                        const std::optional<at::Tensor>& cache_indices,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool varlen,
                        int64_t pad_slot_id,
                        int64_t block_size,
                        const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
                        const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
                        const std::optional<torch::Tensor> &initial_state_idx,
                        const std::optional<torch::Tensor> &cu_chunk_seqlen,
                        const std::optional<torch::Tensor> &last_chunk_indices) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.dim_ngroups_ratio = dim / n_groups;
    params.pad_slot_id = pad_slot_id;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr;
    params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr;
    params.out_ptr = out.data_ptr();
    params.ssm_states_ptr = ssm_states.data_ptr();
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
    params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
    params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
    params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;

    // Set cache parameters - cache is enabled if we have direct cache writing params
    params.cache_enabled = block_idx_first_scheduled_token.has_value();
    params.block_size = static_cast<int>(block_size);

    // Set direct cache writing pointers
    params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
    params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
    params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
    params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
    params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;

    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);

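    // Varlen layout: u/delta/out are (dim, total_tokens) and B/C are
    // (n_groups, dstate, total_tokens). The dense layout handled below is
    // (batch, dim, seqlen) and (batch, n_groups, dstate, seqlen).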
    if (varlen){
        params.B_batch_stride = B.stride(2);
        params.B_group_stride = B.stride(0);
        params.B_dstate_stride = B.stride(1);
        params.C_batch_stride = C.stride(2);
        params.C_group_stride = C.stride(0);
        params.C_dstate_stride = C.stride(1);

        params.u_batch_stride = u.stride(1);
        params.u_d_stride = u.stride(0);
        params.delta_batch_stride = delta.stride(1);
        params.delta_d_stride = delta.stride(0);
        if (has_z) {
            params.z_batch_stride = z.stride(1);
            params.z_d_stride = z.stride(0);
            params.out_z_batch_stride = out_z.stride(1);
            params.out_z_d_stride = out_z.stride(0);
        }
        params.out_batch_stride = out.stride(1);
        params.out_d_stride = out.stride(0);

        params.ssm_states_batch_stride = ssm_states.stride(0);
        params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);

        params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;

    }
    else{
        if (!is_variable_B) {
            params.B_d_stride = B.stride(0);
        } else {
            params.B_batch_stride = B.stride(0);
            params.B_group_stride = B.stride(1);
        }
        params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
        if (!is_variable_C) {
            params.C_d_stride = C.stride(0);
        } else {
            params.C_batch_stride = C.stride(0);
            params.C_group_stride = C.stride(1);
        }
        params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
        params.u_batch_stride = u.stride(0);
        params.u_d_stride = u.stride(1);
        params.delta_batch_stride = delta.stride(0);
        params.delta_d_stride = delta.stride(1);
        if (has_z) {
            params.z_batch_stride = z.stride(0);
            params.z_d_stride = z.stride(1);
            params.out_z_batch_stride = out_z.stride(0);
            params.out_z_d_stride = out_z.stride(1);
        }
        params.out_batch_stride = out.stride(0);
        params.out_d_stride = out.stride(1);
        
        params.ssm_states_batch_stride = ssm_states.stride(0);
        params.ssm_states_dim_stride = ssm_states.stride(1);
        params.ssm_states_dstate_stride = ssm_states.stride(2);

        params.cache_indices_stride = cache_indices.has_value() ? cache_indices.value().stride(0) : 0;
    }
}

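// Host entry point: validates dtypes and shapes, fills SSMParamsBase, and dispatches the
// templated kernel on the current CUDA stream. Note that the output reuses delta's buffer
// (out = delta below) and, when z is given, the gated output reuses z's buffer (out_z = z).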
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                  const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
                  const std::optional<torch::Tensor> &D_,
                  const std::optional<torch::Tensor> &z_,
                  const std::optional<torch::Tensor> &delta_bias_,
                  bool delta_softplus,
                  const std::optional<torch::Tensor> &query_start_loc,
                  const std::optional<torch::Tensor> &cache_indices,
                  const std::optional<torch::Tensor> &has_initial_state,
                  const torch::Tensor &ssm_states,
                  // used to identify padding entries if cache_indices provided
                  // in case of padding, the kernel will return early
                  int64_t pad_slot_id,
                  int64_t block_size,
                  const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
                  const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
                  const std::optional<torch::Tensor> &initial_state_idx,
                  const std::optional<torch::Tensor> &cu_chunk_seqlen,
                  const std::optional<torch::Tensor> &last_chunk_indices) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const bool varlen = query_start_loc.has_value();
    const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
    const int dim = varlen ? sizes[0] : sizes[1];
    const int seqlen = varlen ? sizes[1] : sizes[2];
    const int dstate = A.size(1);
    const int n_groups = varlen ? B.size(0) : B.size(1);

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    if (varlen) {
        CHECK_SHAPE(u, dim, seqlen);
        CHECK_SHAPE(delta, dim, seqlen);
    } else {
        CHECK_SHAPE(u, batch_size, dim, seqlen);
        CHECK_SHAPE(delta, batch_size, dim, seqlen);
    }
    CHECK_SHAPE(A, dim, dstate);
    TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
    if (varlen) {
        CHECK_SHAPE(B, n_groups, dstate, seqlen);
    } else {
        CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); 
    }
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);

    TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
    if (varlen) {
        CHECK_SHAPE(C, n_groups, dstate, seqlen);
    } else {
        CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); 
    }
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }


    if (has_initial_state.has_value()) {
        auto has_initial_state_ = has_initial_state.value();
        TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
        TORCH_CHECK(has_initial_state_.is_cuda());
        CHECK_SHAPE(has_initial_state_, batch_size);
    }


    if (query_start_loc.has_value()) {
        auto query_start_loc_ = query_start_loc.value();
        TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(query_start_loc_.is_cuda());
    }


    if (cache_indices.has_value()) {
        auto cache_indices_ = cache_indices.value();
        TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(cache_indices_.is_cuda());

        // cache_indices can be either 1D (batch_size,) for non-APC mode
        // or 2D (batch_size, max_positions) for APC mode
        const bool is_apc_mode = block_idx_first_scheduled_token.has_value();
        if (is_apc_mode) {
            TORCH_CHECK(cache_indices_.dim() == 2, "cache_indices must be 2D for APC mode");
            TORCH_CHECK(cache_indices_.size(0) == batch_size, "cache_indices first dimension must match batch_size");
        } else {
            CHECK_SHAPE(cache_indices_, batch_size);
        }
    }
   

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    if (has_z) {
        z = z_.value();
        TORCH_CHECK(z.scalar_type() == input_type);
        TORCH_CHECK(z.is_cuda());
        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
        if (varlen){
            CHECK_SHAPE(z, dim, seqlen);
        } else {
            CHECK_SHAPE(z, batch_size, dim, seqlen);
        }
        
        out_z = z;
    }

    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = delta;
    // ssm_states can now be either the same as input_type or float32
    auto state_type = ssm_states.scalar_type();
    TORCH_CHECK(state_type == input_type || state_type == at::ScalarType::Float);
    TORCH_CHECK(ssm_states.is_cuda());
    TORCH_CHECK(ssm_states.stride(-1) == 1);

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_,
                       delta_bias_,
                       ssm_states,
                       has_z,
                       delta_softplus,
                       query_start_loc,
                       cache_indices,
                       has_initial_state,
                       varlen,
                       pad_slot_id,
                       block_size,
                       block_idx_first_scheduled_token,
                       block_idx_last_scheduled_token,
                       initial_state_idx,
                       cu_chunk_seqlen,
                       last_chunk_indices
                       );

    
    const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), ssm_states.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<input_t, weight_t, state_t>(params, stream);
    });
}
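
// Example host-side call (a minimal sketch, not part of the adapted upstream file):
// dense layout, float32 throughout, no z gating and no prefix caching. The tensor names
// and sizes here are illustrative only.
//
//   const int64_t batch = 2, dim = 64, seqlen = 128, dstate = 16, n_groups = 1;
//   auto opts  = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   auto u     = torch::randn({batch, dim, seqlen}, opts);
//   auto delta = torch::randn({batch, dim, seqlen}, opts);
//   auto A     = torch::randn({dim, dstate}, opts);
//   auto B     = torch::randn({batch, n_groups, dstate, seqlen}, opts);
//   auto C     = torch::randn({batch, n_groups, dstate, seqlen}, opts);
//   auto ssm_states = torch::zeros({batch, dim, dstate}, opts);
//   selective_scan_fwd(u, delta, A, B, C,
//                      /*D_=*/std::nullopt, /*z_=*/std::nullopt, /*delta_bias_=*/std::nullopt,
//                      /*delta_softplus=*/true,
//                      /*query_start_loc=*/std::nullopt, /*cache_indices=*/std::nullopt,
//                      /*has_initial_state=*/std::nullopt, ssm_states,
//                      /*pad_slot_id=*/-1, /*block_size=*/2048,
//                      std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
//   // The scan output overwrites delta's buffer (out = delta above); final states land in ssm_states.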