mask.h 40.9 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

/**
 * Sequence-length mask for the interleaved column layout: every element whose
 * column (key) index is >= max_seqlen_k is overwritten with -INFINITY.
 *
 * Column ownership implied by the index math (lane = threadIdx.x % 64):
 * lane group g = lane / 16 starts at column g, its elements are 4 columns
 * apart, and each repeat along size<1,1> advances 16 columns.
 * NOTE(review): assumes 64-lane wavefronts — confirm for the target GPU.
 *
 * @param tensor          2-D (row, col) view; masked in place.
 * @param max_seqlen_k    number of valid columns; columns at or past it are masked.
 * @param col_idx_offset_ column index of the first column of this tile.
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kLanesPerGroup = 16;
    constexpr int kColsPerRepeat = 16;
    constexpr int kColsPerElement = 4;
    const int lane = threadIdx.x % 64;
    const int group_first_col = col_idx_offset_ + lane / kLanesPerGroup;

    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        const int repeat_first_col = group_first_col + rep * kColsPerRepeat;
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const int col = repeat_first_col + elem * kColsPerElement;
            if (col < max_seqlen_k) { continue; }
            // Keep the hierarchical make_coord(...) column coordinate: the
            // original code notes that indexing without it gave wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}

/**
 * Sequence-length mask for the "continuous" column layout: every element whose
 * column (key) index is >= max_seqlen_k is overwritten with -INFINITY.
 *
 * Column ownership implied by the index math (lane = threadIdx.x % 64):
 * lane group g = lane / 16 starts at column 4*g, its elements are adjacent
 * columns, and each repeat along size<1,1> advances 16 columns.
 * NOTE(review): assumes size<1,0> == 4 so the four groups tile 16 columns.
 *
 * @param tensor          2-D (row, col) view; masked in place.
 * @param max_seqlen_k    number of valid columns; columns at or past it are masked.
 * @param col_idx_offset_ column index of the first column of this tile.
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_continuous(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kColsPerRepeat = 16;
    const int lane = threadIdx.x % 64;
    const int group_first_col = col_idx_offset_ + (lane / 16) * 4;

    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        const int repeat_first_col = group_first_col + rep * kColsPerRepeat;
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const bool out_of_range = repeat_first_col + elem >= max_seqlen_k;
            if (!out_of_range) { continue; }
            // Index the column through make_coord(elem, rep); the original code
            // notes that flat indexing here produced wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}

/**
 * Local (sliding-window) attention mask for the interleaved column layout.
 *
 * For a row with query index q, the kept columns are
 *   [max(0, q + (max_seqlen_k - max_seqlen_q) - window_size_left),
 *    min(max_seqlen_k, q + 1 + (max_seqlen_k - max_seqlen_q) + window_size_right))
 * and everything else becomes -INFINITY. When HasWSLeft is false the left bound
 * is ignored (this is how causal masking is expressed).
 *
 * Column ownership (lane = threadIdx.x % 64): lane group lane / 16 starts at
 * column lane / 16, elements are 4 columns apart, repeats advance 16 columns.
 *
 * @param tensor          2-D (row, col) view; masked in place.
 * @param col_idx_offset_ column index of the first column of this tile.
 * @param row_idx_offset  row index of this thread's first row.
 * @param warp_row_stride row-index step between consecutive mode-0 entries.
 */
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kColsPerRepeat = 16;
    constexpr int kColsPerElement = 4;
    const int lane = threadIdx.x % 64;
    const int group_first_col = col_idx_offset_ + lane / 16;
    // Shift that maps a query index onto the key index space.
    const int q_to_k_shift = max_seqlen_k - max_seqlen_q;

    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        const int q_idx = row_idx_offset + row * warp_row_stride;
        const int lo = std::max(0, q_idx + q_to_k_shift - window_size_left);
        const int hi = std::min(max_seqlen_k, q_idx + 1 + q_to_k_shift + window_size_right);
        #pragma unroll
        for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
            #pragma unroll
            for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                const int col = group_first_col + rep * kColsPerRepeat + elem * kColsPerElement;
                const bool keep = col < hi && (!HasWSLeft || col >= lo);
                if (!keep) {
                    tensor(row, make_coord(elem, rep)) = -INFINITY;
                }
            }
        }
    }
}

/**
 * Local (sliding-window) attention mask for the "continuous" column layout.
 *
 * For a row with query index q, the kept columns are
 *   [max(0, q + (max_seqlen_k - max_seqlen_q) - window_size_left),
 *    min(max_seqlen_k, q + 1 + (max_seqlen_k - max_seqlen_q) + window_size_right))
 * and everything else becomes -INFINITY. When HasWSLeft is false the left bound
 * is ignored.
 *
 * Column ownership (lane = threadIdx.x % 64): lane group lane / 16 starts at
 * column 4 * (lane / 16), elements are adjacent, repeats advance 16 columns.
 *
 * @param tensor          2-D (row, col) view; masked in place.
 * @param col_idx_offset_ column index of the first column of this tile.
 * @param row_idx_offset  row index of this thread's first row.
 * @param warp_row_stride row-index step between consecutive mode-0 entries.
 */
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local_continuous(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kColsPerRepeat = 16;
    const int lane = threadIdx.x % 64;
    const int group_first_col = col_idx_offset_ + (lane / 16) * 4;
    const int q_to_k_shift = max_seqlen_k - max_seqlen_q;

    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        const int q_idx = row_idx_offset + row * warp_row_stride;
        const int left_limit = std::max(0, q_idx + q_to_k_shift - window_size_left);
        const int right_limit = std::min(max_seqlen_k, q_idx + 1 + q_to_k_shift + window_size_right);
        #pragma unroll
        for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
            const int repeat_first_col = group_first_col + rep * kColsPerRepeat;
            #pragma unroll
            for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
                const int col = repeat_first_col + elem;
                const bool past_right = col >= right_limit;
                const bool before_left = HasWSLeft && col < left_limit;
                if (past_right || before_left) {
                    tensor(row, make_coord(elem, rep)) = -INFINITY;
                }
            }
        }
    }
}

/**
 * Causal mask for the interleaved column layout.
 *
 * Implemented as local masking with no left bound (HasWSLeft=false, so the
 * left window value is never read) and a right window of 0, i.e. a row with
 * query index q keeps columns < min(max_seqlen_k, q + 1 + max_seqlen_k - max_seqlen_q).
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kIgnoredLeftWindow = -1;
    constexpr int kRightWindow = 0;
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride,
                                          kIgnoredLeftWindow, kRightWindow);
}
/**
 * Causal mask for the "continuous" column layout.
 *
 * Implemented as local masking with no left bound (HasWSLeft=false, so the
 * left window value is never read) and a right window of 0.
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal_continuous(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kIgnoredLeftWindow = -1;
    constexpr int kRightWindow = 0;
    apply_mask_local_continuous</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                                     max_seqlen_q, warp_row_stride,
                                                     kIgnoredLeftWindow, kRightWindow);
}



/**
 * Column bound mask for the "continuous" column layout where the columns are
 * bounded by max_seqlen_q: every element whose column index is >= max_seqlen_q
 * is overwritten with -INFINITY.
 *
 * Column ownership (lane = threadIdx.x % 64): lane group lane / 16 starts at
 * column 4 * (lane / 16), elements are adjacent, repeats advance 16 columns.
 *
 * @param tensor          2-D (row, col) view; masked in place.
 * @param max_seqlen_q    number of valid columns; columns at or past it are masked.
 * @param col_idx_offset_ column index of the first column of this tile.
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_trans(Tensor<Engine, Layout> &tensor, const int max_seqlen_q,
                                           const int col_idx_offset_ = 0) {
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    constexpr int kColsPerRepeat = 16;
    const int lane = threadIdx.x % 64;
    const int group_first_col = col_idx_offset_ + (lane / 16) * 4;

    #pragma unroll
    for (int rep = 0; rep < size<1, 1>(tensor); ++rep) {
        #pragma unroll
        for (int elem = 0; elem < size<1, 0>(tensor); ++elem) {
            const int col = group_first_col + rep * kColsPerRepeat + elem;
            if (col < max_seqlen_q) { continue; }
            // The hierarchical make_coord(...) column coordinate is required;
            // the original code notes flat indexing gave wrong results.
            #pragma unroll
            for (int row = 0; row < size<0>(tensor); ++row) {
                tensor(row, make_coord(elem, rep)) = -INFINITY;
            }
        }
    }
}

// Local-window / causal mask for a tile in the "continuous" column layout
// (lane group g = (threadIdx.x % 64) / 16 starts at column 4*g; its elements
// are adjacent columns; each repeat along size<1,1> advances 16 columns).
//
// NOTE(review): the index math suggests the tile holds *transposed* scores:
// rows behave like key positions (row limits are clamped to [0, max_seqlen_k))
// and columns like query positions. Confirm against the caller.
//
// HasWSLeft == true : for each column c, rows outside
//     [max(0, c + max_seqlen_k - max_seqlen_q - window_size_left),
//      min(max_seqlen_k, c + 1 + max_seqlen_k - max_seqlen_q + window_size_right))
//   are set to -INFINITY.
// HasWSLeft == false: only a causal condition is applied (see the comment in
//   that branch); apply_mask_causal_trans calls this with window_size_left = -1
//   and window_size_right = 0.
// Neither branch masks columns >= max_seqlen_q.
//
// Parameters: col_idx_offset_ is the column index of the tile's first column,
// row_idx_offset the row index of this thread's first row, and warp_row_stride
// the row-index step between consecutive mode-0 entries of `tensor`.
template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local_trans(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    // tensor is a 2-D (row, col) view. An earlier version of this function used
    // the interleaved layout (lane / 16 start, element stride 4); it now uses
    // the continuous layout below.
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 64;
    // Continuous layout: lane group start column is 4 * (lane_id / 16); the
    // group's elements are 1 column apart; each repeat advances 16 columns.
    const int col_idx_offset = col_idx_offset_ + (lane_id / 16) * 4;
    const int stride_between_each_repeat = 16;
    const int stride_between_each_thread = 1;

    if constexpr (HasWSLeft) {
        // NOTE(review): unlike the other loops in this file, this outer loop has
        // no #pragma unroll.
        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
            const int col_idx_base = col_idx_offset + nj * stride_between_each_repeat;
            #pragma unroll
            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                const int col_idx = col_idx_base + j * stride_between_each_thread;
                // Rows in [row_idx_limit_up, row_idx_limit_down) are kept for this column.
                const int row_idx_limit_up = std::max(0, col_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                const int row_idx_limit_down = std::min(max_seqlen_k, col_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    const int row_idx = row_idx_base;
                    if (row_idx < row_idx_limit_up || row_idx >= row_idx_limit_down) {
                        tensor(mi, make_coord(j, nj)) = -INFINITY;
                    }
                }
            
            }
        }
    }
    else {
        #pragma unroll
        for (int mi = 0; mi < size<0>(tensor); ++mi) {
            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
            const int row_idx = row_idx_base;
            // Column bound derived from the row index. Note the sign is
            // (max_seqlen_q - max_seqlen_k), the reverse of the HasWSLeft branch.
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_q - max_seqlen_k - window_size_left);
            // NOTE(review): col_idx_limit_right is computed but never used below,
            // and it is clamped with max_seqlen_k although it bounds a column index.
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_q - max_seqlen_k + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * stride_between_each_repeat;
                #pragma unroll  
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j * stride_between_each_thread;
                    
                    // With window_size_left == -1 (as passed by apply_mask_causal_trans),
                    // col_idx_limit_left = max(0, row_idx + max_seqlen_q - max_seqlen_k + 1),
                    // so this masks exactly col_idx < row_idx + max_seqlen_q - max_seqlen_k.
                    // The "+ 1" on col_idx cancels the -1 window.
                    // NOTE(review): other window_size_left values would not give a
                    // symmetric window here; confirm only the causal wrapper reaches
                    // this branch.
                    if (col_idx + 1 < col_idx_limit_left) {
                        tensor(mi, make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
        }
    }
}


/**
 * Causal mask for the transposed "continuous"-layout tile.
 *
 * Delegates to apply_mask_local_trans with HasWSLeft=false, a left window of
 * -1 (which that branch's "+ 1" comparison relies on) and a right window of 0.
 */
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal_trans(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    constexpr int kLeftWindow = -1;
    constexpr int kRightWindow = 0;
    apply_mask_local_trans</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                                max_seqlen_q, warp_row_stride,
                                                kLeftWindow, kRightWindow);
}


template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {
    // Key / query sequence lengths used for bounds and diagonal alignment.
    const int max_seqlen_k, max_seqlen_q;
    // Local-attention window extents (read only by the local/causal paths).
    const int window_size_left, window_size_right;
    // ALiBi slope; stored as 0 whenever Has_alibi is false.
    const float alibi_slope;

    // Captures the masking parameters consumed by the apply_mask* methods.
    // alibi_slope is discarded (stored as 0) unless Has_alibi is true.
    // Uses a float literal so the conditional stays in single precision.
    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope=0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , alibi_slope(!Has_alibi ? 0.f : alibi_slope) {
    }

    // Applies ALiBi biases and the causal / local-window / sequence-length masks to
    // an accumulator fragment in place (masked scores become -INFINITY).
    // This variant uses the interleaved column layout. With lane = threadIdx.x & 63,
    // lane group g = lane / 16 holds columns g, g + 4, g + 8, ... (step 4) within
    // each 16-column repeat.
    //
    // Template parameters:
    //   Causal_mask - whether this particular iteration needs causal masking.
    //   Is_even_MN  - when false, columns at or beyond max_seqlen_k are masked.
    // Parameters:
    //   tensor_         - 3-mode accumulator fragment whose first mode has size 4.
    //   col_idx_offset_ - column (key) index of the first column of this tile.
    //   row_idx_offset  - row (query) index of this thread's first row.
    //   warp_row_stride - row-index step between consecutive rows of the fragment.
    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        if constexpr (Need_masking) {
            // Reinterpret tensor_ (MMA=4, MMA_M, MMA_N) as a 2-D (row, col) view. The
            // loops below treat mode 0 as a flat row index and mode 1 as
            // (elements per repeat, repeats along N).
            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;

            // Column offset of this lane's first element:
            // lanes 0-15 -> +0, 16-31 -> +1, 32-47 -> +2, 48-63 -> +3.
            constexpr int kLanesPerGroup = 16;
            constexpr int kColsPerRepeat = 16;   // columns spanned by one repeat along N
            constexpr int kColsPerElement = 4;   // column step between a lane's elements
            const int lane_id = threadIdx.x & 63;
            const int col_idx_offset = col_idx_offset_ + lane_id / kLanesPerGroup;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * kColsPerRepeat;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        // Within a repeat the four lane groups interleave:
                        //   t0 t16 t32 t48 | t0 t16 t32 t48 | ...
                        const int col_idx = col_idx_base + j * kColsPerElement;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local: column-only bias and/or bound check.
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    const int row_idx = row_idx_offset + mi * warp_row_stride;
                    const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                    const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                    #pragma unroll
                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                        const int col_idx_base = col_idx_offset + nj * kColsPerRepeat;
                        #pragma unroll
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            // Consecutive elements of one lane are kColsPerElement columns apart.
                            const int col_idx = col_idx_base + j * kColsPerElement;
                            if constexpr (Has_alibi) {
                                if constexpr (Is_causal) {
                                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                                } else {
                                    tensor(mi, make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                }
                            }
                            if constexpr (Causal_mask) {
                                // col_idx_limit_right is clamped to max_seqlen_k, so this
                                // also masks out-of-range columns.
                                if (col_idx >= col_idx_limit_right) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                            if constexpr (Is_local) {
                                if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                            // Causal and local masking already bound columns by max_seqlen_k
                            // (via col_idx_limit_right); otherwise apply the bound here.
                            if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                if (col_idx >= max_seqlen_k) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // Same as apply_mask, but for the "continuous" column layout. With
    // lane = threadIdx.x % 64, lane group g = lane / 16 holds the adjacent
    // columns 4*g, 4*g + 1, ... within each 16-column repeat.
    //
    // Template parameters:
    //   Causal_mask - whether this particular iteration needs causal masking.
    //   Is_even_MN  - when false, columns at or beyond max_seqlen_k are masked.
    // Parameters:
    //   tensor_         - 3-mode accumulator fragment whose first mode has size 4.
    //   col_idx_offset_ - column (key) index of the first column of this tile.
    //   row_idx_offset  - row (query) index of this thread's first row.
    //   warp_row_stride - row-index step between consecutive rows of the fragment.
    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask_continuous(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        if constexpr (Need_masking) {
            // Reinterpret tensor_ (MMA=4, MMA_M, MMA_N) as a 2-D (row, col) view. The
            // loops below treat mode 0 as a flat row index and mode 1 as
            // (elements per repeat, repeats along N).
            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;

            // Column offset of this lane's first element:
            // lanes 0-15 -> +0, 16-31 -> +4, 32-47 -> +8, 48-63 -> +12.
            constexpr int kLanesPerGroup = 16;
            constexpr int kColsPerGroup = 4;     // adjacent columns owned by one lane group
            constexpr int kColsPerRepeat = 16;   // columns spanned by one repeat along N
            constexpr int kColsPerElement = 1;   // a lane's elements are adjacent columns
            const int lane_id = threadIdx.x % 64;
            const int col_idx_offset = col_idx_offset_ + (lane_id / kLanesPerGroup) * kColsPerGroup;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * kColsPerRepeat;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        // Within a repeat: lanes 0-15 own columns 0-3, lanes 16-31
                        // own 4-7, lanes 32-47 own 8-11, lanes 48-63 own 12-15.
                        const int col_idx = col_idx_base + j * kColsPerElement;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local: column-only bias and/or bound check.
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    const int row_idx = row_idx_offset + mi * warp_row_stride;
                    const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                    const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                    #pragma unroll
                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                        const int col_idx_base = col_idx_offset + nj * kColsPerRepeat;
                        #pragma unroll
                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
                            // Consecutive elements of one lane are adjacent columns.
                            const int col_idx = col_idx_base + j * kColsPerElement;
                            if constexpr (Has_alibi) {
                                if constexpr (Is_causal) {
                                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                                } else {
                                    tensor(mi, make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                }
                            }
                            if constexpr (Causal_mask) {
                                // col_idx_limit_right is clamped to max_seqlen_k, so this
                                // also masks out-of-range columns.
                                if (col_idx >= col_idx_limit_right) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                            if constexpr (Is_local) {
                                if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                            // Causal and local masking already bound columns by max_seqlen_k
                            // (via col_idx_limit_right); otherwise apply the bound here.
                            if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                if (col_idx >= max_seqlen_k) {
                                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
	
	
	 template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask_continuous_fp8(Tensor<Engine, Layout> &tensor_,
                                                              const int col_idx_offset_,
                                                              const int row_idx_offset,
                                                              const int warp_row_stride) {
        // FP8 variant of the in-place score mask. When Has_alibi is set it adds the ALiBi
        // bias. It writes -INFINITY to entries excluded by the causal mask (Causal_mask),
        // the sliding window (Is_local), or the key-length bound (max_seqlen_k, only
        // checked when !Is_even_MN and neither causal nor local masking is active).
        //
        // tensor_         : accumulator tile; rank 3, first mode must have extent 8.
        // col_idx_offset_ : key column added to every per-lane column offset.
        // row_idx_offset  : query row for m == 0; row m is row_idx_offset + m * warp_row_stride.
        // warp_row_stride : row distance between consecutive m.
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 8, "First dimension must be 8");

        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        if constexpr (Need_masking) {
            // Row/column view of the accumulator. The exact regrouping is defined by
            // flash::convert_layout_acc_rowcol (not visible here).
            Tensor scores = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));

            // Only the non-causal ALiBi term, the local window and the causal mask depend
            // on the query row; everything else is decided from the column alone.
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;

            // Lanes are grouped by 16 within a 64-lane group. Each group starts 8 columns
            // after the previous one, and successive n_rep repeats are 32 columns apart.
            const int lane = threadIdx.x % 64;
            const int lane_col_start = col_idx_offset_ + (lane / 16) * 8;

            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int n_rep = 0; n_rep < size<1, 1>(scores); ++n_rep) {
                    #pragma unroll
                    for (int n_in = 0; n_in < size<1, 0>(scores); ++n_in) {
                        // Consecutive n_in values map to consecutive key columns.
                        const int col = lane_col_start + n_rep * 32 + n_in;
                        #pragma unroll
                        for (int m = 0; m < size<0>(scores); ++m) {
                            if constexpr (Has_alibi) {
                                scores(m, make_coord(n_in, n_rep)) += alibi_slope * col;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col >= max_seqlen_k) { scores(m, make_coord(n_in, n_rep)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                // Shift that aligns query rows with the tail of the key sequence.
                const int diag_shift = max_seqlen_k - max_seqlen_q;
                #pragma unroll
                for (int m = 0; m < size<0>(scores); ++m) {
                    const int row = row_idx_offset + m * warp_row_stride;
                    const int aligned_row = row + diag_shift;
                    // Visible key columns are [left_bound, right_bound).
                    const int left_bound = std::max(0, aligned_row - window_size_left);
                    const int right_bound = std::min(max_seqlen_k, aligned_row + 1 + window_size_right);
                    #pragma unroll
                    for (int n_rep = 0; n_rep < size<1, 1>(scores); ++n_rep) {
                        #pragma unroll
                        for (int n_in = 0; n_in < size<1, 0>(scores); ++n_in) {
                            const int col = lane_col_start + n_rep * 32 + n_in;

                            if constexpr (Has_alibi && Is_causal) {
                                scores(m, make_coord(n_in, n_rep)) += alibi_slope * col;
                            } else if constexpr (Has_alibi) {
                                scores(m, make_coord(n_in, n_rep)) -= alibi_slope * abs(aligned_row - col);
                            }

                            // Causal and local bounds already clamp to max_seqlen_k, so the
                            // plain key-length check is only needed when neither is active.
                            bool drop = false;
                            if constexpr (Causal_mask) {
                                drop = col >= right_bound;
                            } else if constexpr (Is_local) {
                                drop = col < left_bound || col >= right_bound;
                            } else if constexpr (!Is_even_MN) {
                                drop = col >= max_seqlen_k;
                            }
                            if (drop) { scores(m, make_coord(n_in, n_rep)) = -INFINITY; }
                        }
                    }
                }
            }
        }
    };

    template <bool Causal_mask=false, bool Is_even_MN=true,
            bool Use_alibi_sqrt=false, bool Use_qq_bias=false, bool Use_mm_prefix=false,
            typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask_continuous_unified(
        Tensor<Engine, Layout> &tensor_,
        const int col_idx_offset_,
        const int row_idx_offset,
        const int warp_row_stride,
        const int context_len,
        const void * __restrict__ qq_bias_ptr = nullptr,
        const int qq_bias_stride_0 = 0,
        const int * __restrict__ mm_prefix_range_ptr = nullptr,
        const int max_mm_ranges = 0,
        const int bidb = 0,
        const float softmax_scale = 1.0f
    ) {
        // In-place score masking/biasing for the "unified" attention path. Depending on
        // the template flags it:
        //   * writes -INFINITY for causal (Causal_mask), sliding-window (Is_local) and
        //     key-length (max_seqlen_k, when !Is_even_MN) exclusions;
        //   * lifts those exclusions for (query, key) pairs that both fall inside one of
        //     the per-batch bidirectional ranges (Use_mm_prefix), but never for key
        //     columns at or past max_seqlen_k;
        //   * adds an ALiBi bias, either linear alibi_slope * (col - context_len) or
        //     -alibi_slope * sqrt(query_abs_pos - col) (Use_alibi_sqrt);
        //   * adds qq_bias[row][col - context_len] / softmax_scale (Use_qq_bias).
        //
        // tensor_             : accumulator tile; rank 3, first mode must have extent 4.
        // col_idx_offset_     : key column added to every per-lane column offset.
        // row_idx_offset      : query row for mi == 0; row mi is row_idx_offset + mi * warp_row_stride.
        // warp_row_stride     : row distance between consecutive mi.
        // context_len         : key offset subtracted for the linear ALiBi term and the qq-bias column.
        // qq_bias_ptr         : read as a row-major float matrix with row stride qq_bias_stride_0.
        // mm_prefix_range_ptr : per batch (bidb), max_mm_ranges pairs [start, end]; a pair with
        //                       start >= end is ignored; both ends are compared inclusively.
        // softmax_scale       : divisor applied to the qq bias (presumably because scores are
        //                       scaled later in softmax -- TODO confirm against the caller).
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");

        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local
                                        || !Is_even_MN || Use_qq_bias || Use_mm_prefix;
        if constexpr (!Need_masking) return;

        // The fast path needs only the column index. Row-dependent features (local window,
        // causal mask, mm-prefix ranges, qq bias, non-causal or sqrt ALiBi) force the
        // general path.
        static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local
                                        && !Causal_mask && !Use_mm_prefix && !Use_qq_bias
                                        && !(Has_alibi && Use_alibi_sqrt);

        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));

        // Lanes are grouped by 16 within a 64-lane group. Each group starts 4 columns after
        // the previous one, and successive nj repeats are 16 columns apart.
        const int lane_id = threadIdx.x % 64;
        const int col_idx_offset = col_idx_offset_ + ((lane_id >> 4) << 2);

        if constexpr (Col_idx_only) {
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + (nj << 4);
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    #pragma unroll
                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
                        if constexpr (Has_alibi) {
                            // Must match the linear ALiBi term of the general path below.
                            // The same query row can be processed by both paths (e.g.
                            // Causal_mask toggled between key blocks). Differing per-row
                            // offsets would then break the softmax normalization.
                            tensor(mi, make_coord(j, nj)) += alibi_slope * (col_idx - context_len);
                        }
                        if constexpr (!Is_even_MN) {
                            if (col_idx >= max_seqlen_k) {
                                tensor(mi, make_coord(j, nj)) = -INFINITY;
                            }
                        }
                    }
                }
            }
        } else {
            #pragma unroll
            for (int mi = 0; mi < size<0>(tensor); ++mi) {
                const int row_idx = row_idx_offset + mi * warp_row_stride;
                // Query row aligned with the tail of the key sequence.
                const int query_abs_pos = row_idx + (max_seqlen_k - max_seqlen_q);

                // Visible key columns are [col_idx_limit_left, col_idx_limit_right).
                const int col_idx_limit_left = std::max(0, query_abs_pos - window_size_left);
                const int col_idx_limit_right = std::min(max_seqlen_k,
                    query_abs_pos + 1 + window_size_right);

                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + (nj << 4);
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;

                        bool is_masked = false;

                        if constexpr (Causal_mask) {
                            is_masked |= (col_idx >= col_idx_limit_right);
                        }
                        if constexpr (Is_local) {
                            is_masked |= (col_idx >= col_idx_limit_right
                                    || col_idx < col_idx_limit_left);
                        }
                        if constexpr (!Is_even_MN) {
                            if constexpr (!Causal_mask && !Is_local) {
                                // Causal/local limits already clamp to max_seqlen_k; only the
                                // pure key-length bound is needed here.
                                is_masked |= (col_idx >= max_seqlen_k);
                            }
                        }

                        if constexpr (Use_mm_prefix) {
                            bool in_bidirectional = false;
                            #pragma unroll
                            for (int i = 0; i < max_mm_ranges; ++i) {
                                const int range_start = mm_prefix_range_ptr[
                                    bidb * max_mm_ranges * 2 + i * 2];
                                const int range_end = mm_prefix_range_ptr[
                                    bidb * max_mm_ranges * 2 + i * 2 + 1];
                                const bool is_valid = (range_start < range_end);
                                const bool q_in_range = is_valid
                                                    && (query_abs_pos >= range_start)
                                                    && (query_abs_pos <= range_end);
                                const bool k_in_range = is_valid
                                                    && (col_idx >= range_start)
                                                    && (col_idx <= range_end);
                                in_bidirectional |= (q_in_range && k_in_range);
                            }
                            // A bidirectional range may lift causal/local masking, but must
                            // never re-expose key columns past the end of the key sequence.
                            if (in_bidirectional && col_idx < max_seqlen_k) is_masked = false;
                        }

                        // Masked entries get -inf and skip the bias terms below.
                        if (is_masked) {
                            tensor(mi, make_coord(j, nj)) = -INFINITY;
                            continue;
                        }

                        if constexpr (Has_alibi) {
                            if constexpr (Use_alibi_sqrt) {
                                // Triton reference: -sqrt(max(0, query_abs_pos - seq_offset)).
                                const float rel = float(query_abs_pos - col_idx);
                                const float alibi_offset = rel >= 0.f ? -sqrtf(rel) : 0.f;
                                tensor(mi, make_coord(j, nj)) += alibi_slope * alibi_offset;
                            } else {
                                // Triton reference: alibi_offset = seq_offset - context_len.
                                tensor(mi, make_coord(j, nj)) +=
                                    alibi_slope * (col_idx - context_len);
                            }
                        }

                        if constexpr (Use_qq_bias) {
                            const float * qq_bias = static_cast<const float *>(qq_bias_ptr);
                            const int query_pos = row_idx;
                            const int key_rel_pos = col_idx - context_len;
                            if (query_pos >= 0 && query_pos < max_seqlen_q &&
                                key_rel_pos >= 0 && key_rel_pos < qq_bias_stride_0) {
                                const float bias_val =
                                    qq_bias[query_pos * qq_bias_stride_0 + key_rel_pos];
                                tensor(mi, make_coord(j, nj)) += bias_val / softmax_scale;
                            }
                        }
                    }
                }
            }
        }
    };

};

} // namespace flash