sampling_topp_kernels.cu 63.9 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
19
#elif (CUDART_VERSION >= 11000)
Li Zhang's avatar
Li Zhang committed
20
21
#include <cub/cub.cuh>
#else
xiabo's avatar
xiabo committed
22
23
// #include "3rdparty/cub/cub.cuh"
#include <cub/cub.cuh>
Li Zhang's avatar
Li Zhang committed
24
25
#endif

lvhan028's avatar
lvhan028 committed
26
27
28
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
Li Zhang's avatar
Li Zhang committed
29
30
31
32

// Compile-time switch for the single-pass top-p path; presumably non-zero enables it.
// NOTE(review): the code reading this flag is outside the visible chunk — confirm semantics there.
constexpr int ENABLE_SINGLE_PASS_TOP_P{0};
// Threshold associated with the single-pass path (written as a float literal, same value as 0.9).
// NOTE(review): presumably compared against top_p at the call site — confirm.
constexpr float SINGLE_PASS_THRESHOLD{0.9F};

lvhan028's avatar
lvhan028 committed
33
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429

namespace segmented_topp_impl {

// Per-lane copy type for a warp-wide load of HALF_ELEMENTS_PER_WARP_LOAD half-sized elements:
//   128 -> int2 (8 bytes per lane), 64 -> int (4 bytes per lane), 32 -> half (2 bytes per lane);
// every other count falls through to int4 (16 bytes per lane).
template<int HALF_ELEMENTS_PER_WARP_LOAD>
using Copy_half_t = typename std::conditional<
    HALF_ELEMENTS_PER_WARP_LOAD == 128,
    int2,
    typename std::conditional<
        HALF_ELEMENTS_PER_WARP_LOAD == 64,
        int,
        typename std::conditional<HALF_ELEMENTS_PER_WARP_LOAD == 32, half, int4>::type>::type>::type;

// Per-lane copy type for a warp that loads ELEMENTS_PER_WARP_LOAD elements of type T; the count
// is first expressed in half-sized units so the mapping above can be reused.
template<typename T, int ELEMENTS_PER_WARP_LOAD>
using Copy_t = Copy_half_t<ELEMENTS_PER_WARP_LOAD * (sizeof(T) / sizeof(half))>;

// Maps a floating-point key type to the unsigned integer type of the same width, so that key
// bit patterns can be manipulated with integer operations. Only float and __half are mapped;
// the primary template is deliberately empty so other types fail to compile when ::Type is used.
template<typename T>
struct Float_as_int_ {};

template<>
struct Float_as_int_<float> {
    typedef uint32_t Type;
};

template<>
struct Float_as_int_<__half> {
    typedef uint16_t Type;
};

// Parameter bundles for the single-pass top-p kernel: 256 threads per CTA, int32 values.
// NOTE(review): the last template argument presumably becomes Kernel_params::KEYS_PER_LDG
// (keys per vectorized load) — confirm against the Segmented_topk_kernel_params definition.
typedef Segmented_topk_kernel_params<float, int32_t, 256, 2>  kernel_params_float;
typedef Segmented_topk_kernel_params<float, int32_t, 256, 1>  kernel_params_float_1;
typedef Segmented_topk_kernel_params<__half, int32_t, 256, 4> kernel_params_half;
typedef Segmented_topk_kernel_params<__half, int32_t, 256, 1> kernel_params_half_1;

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Reinterprets the raw 32-bit pattern of an fp32 key as the float it encodes.
static inline __device__ float to_float(uint32_t src)
{
    const float value = __int_as_float(src);
    return value;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Reinterprets the raw 16-bit pattern of an fp16 key and widens the result to float.
static inline __device__ float to_float(uint16_t src)
{
    return __half2float(__ushort_as_half(src));
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// sort one segment per cta
// Sorts one segment per CTA in descending key order, moving the int32 values along with the keys.
// Segment b starts at b * stride_items and holds active_counts[b] valid items. It is read from
// and written back to d_keys_out / d_values_out, i.e. the sort is in place; d_keys_in,
// d_values_in and num_items_ are not read. At most BLOCK_THREADS * ELEMENTS_PER_THREAD items of
// a segment are covered by one CTA.
template<typename T_SCORE, int BLOCK_THREADS, int ELEMENTS_PER_THREAD>
__global__ void blockSortKernel(const T_SCORE* d_keys_in,
                                T_SCORE*       d_keys_out,
                                const int32_t* d_values_in,
                                int32_t*       d_values_out,
                                const int32_t* active_counts,
                                int            num_items_,
                                int            stride_items,
                                int            num_segments)
{
    using SortT = cub::BlockRadixSort<T_SCORE, BLOCK_THREADS, ELEMENTS_PER_THREAD, int32_t>;
    __shared__ typename SortT::TempStorage sort_storage;

    // Both early exits depend only on blockIdx, so the whole CTA leaves together.
    if (blockIdx.x >= num_segments) {
        return;
    }
    const int valid_items = active_counts[blockIdx.x];
    if (valid_items == 0) {
        return;
    }

    const int32_t seg_offset = blockIdx.x * stride_items;
    T_SCORE*      keys_ptr   = d_keys_out + seg_offset;
    int32_t*      vals_ptr   = d_values_out + seg_offset;

    T_SCORE keys[ELEMENTS_PER_THREAD];
    int32_t vals[ELEMENTS_PER_THREAD];

    // Striped load; slots beyond valid_items are padded with key 0 and value -1.
    cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, keys_ptr, keys, valid_items, 0);
    cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, vals_ptr, vals, valid_items, -1);
    __syncthreads();

    // The sort result does not depend on the input arrangement; the output comes back striped,
    // which matches the striped store that follows.
    SortT(sort_storage).SortDescendingBlockedToStriped(keys, vals);

    cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, keys_ptr, keys, valid_items);
    cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, vals_ptr, vals, valid_items);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// block sort kernel
/// Host launcher for blockSortKernel: sorts every segment in place (d_keys_out / d_values_out),
/// descending by key, with one CTA per segment on `stream`.
///
/// `num_items` must be an upper bound on every entry of `active_counts`. It selects the smallest
/// compiled kernel variant whose capacity (BLOCK_THREADS * ELEMENTS_PER_THREAD) covers it:
///   num_items <= 1024 -> <128 * k threads, 1 item>, k = ceil(num_items / 128)
///   num_items <= 2048 -> <1024 threads, 2 items>
///   num_items <= 4096 -> <1024 threads, 4 items>
/// Larger values have no matching variant and are rejected (assert in debug builds, then exit(1)),
/// instead of indexing past the end of the kernel table.
/// Non-positive num_items or num_segments is a no-op.
template<typename T_SCORE>
void blockSort(const T_SCORE* d_keys_in,
               T_SCORE*       d_keys_out,
               const int32_t* d_values_in,
               int32_t*       d_values_out,
               const int32_t* active_counts,
               int            num_items,
               int            stride_items,
               int            num_segments,
               cudaStream_t   stream)
{
    // Nothing to sort; also keeps the index arithmetic below away from negative values.
    if (num_items <= 0 || num_segments <= 0) {
        return;
    }

    using kernel_func = void (*)(const T_SCORE* d_keys_in,
                                 T_SCORE*       d_keys_out,
                                 const int32_t* d_values_in,
                                 int32_t*       d_values_out,
                                 const int32_t* active_counts,
                                 int            num_items,
                                 int            stride_items,
                                 int            num_segments);

    static const kernel_func kernel_funcs[] = {
        &blockSortKernel<T_SCORE, 128, 1>,
        &blockSortKernel<T_SCORE, 256, 1>,
        &blockSortKernel<T_SCORE, 384, 1>,
        &blockSortKernel<T_SCORE, 512, 1>,
        &blockSortKernel<T_SCORE, 640, 1>,
        &blockSortKernel<T_SCORE, 768, 1>,
        &blockSortKernel<T_SCORE, 896, 1>,
        &blockSortKernel<T_SCORE, 1024, 1>,
        &blockSortKernel<T_SCORE, 1024, 2>,
        &blockSortKernel<T_SCORE, 1024, 4>,
        //&blockSortKernel<T_SCORE, 1024, 6>,
    };
    // The index selection below hard-codes this table layout.
    static_assert(sizeof(kernel_funcs) / sizeof(kernel_funcs[0]) == 10, "kernel table layout changed");

    int kernel_index  = 0;
    int block_threads = 0;
    if (num_items <= 1024) {
        // One item per thread; block size is num_items rounded up to a multiple of 128.
        kernel_index  = div_up(num_items, 128) - 1;  // 0..7
        block_threads = (kernel_index + 1) * 128;
    }
    else if (num_items <= 2048) {
        kernel_index  = 8;  // blockSortKernel<T_SCORE, 1024, 2>
        block_threads = 1024;
    }
    else if (num_items <= 4096) {
        kernel_index  = 9;  // blockSortKernel<T_SCORE, 1024, 4>
        block_threads = 1024;
    }
    else {
        // No compiled variant can hold the segment; fail loudly rather than launch garbage.
        assert(false && "blockSort: num_items exceeds the largest supported segment size (4096)");
        exit(1);
    }

    dim3 block(block_threads);
    dim3 grid(num_segments);

    kernel_funcs[kernel_index]<<<grid, block, 0, stream>>>(
        d_keys_in, d_keys_out, d_values_in, d_values_out, active_counts, num_items, stride_items, num_segments);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Stateful prefix functor for cub::BlockScan. Each block-wide scan is seeded with the running
/// total of all earlier block aggregates, and the total is then advanced by the new aggregate.
/// It is invoked by the first warp of the block; thread 0's return value seeds the scan.
struct BlockPrefixCallbackOp {
    int running_total;  // sum of every block aggregate seen so far

    __device__ BlockPrefixCallbackOp(uint32_t initial_total): running_total(initial_total) {}

    __device__ int operator()(uint32_t block_aggregate)
    {
        const uint32_t prefix_before = running_total;
        running_total                = running_total + block_aggregate;
        return prefix_before;
    }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

// Set to 1 to compile printf tracing into segmented_top_p_single_pass.
#define DO_DEBUG_PRINT 0

// Fraction of each thread's items that the single-pass kernel keeps in shared memory;
// the rest stay in registers.
constexpr float SMEM_FRACTION{0.5F};
// Tolerance used when comparing the accumulated probability mass with top_p.
constexpr float P_EPSILON{0.01F};

// Capacity, in (key, index) pairs, of the shared-memory list of selected elements.
constexpr int MAX_TOP_K{3072};
// Threads per warp.
constexpr int WARP_SZ{32};

// Single-pass top-p selection, one CTA per segment (segment = blockIdx.y * gridDim.x + blockIdx.x).
//
// Each thread owns ITEMS_PER_THREAD consecutive keys of its segment. The first
// ITEMS_PER_THREAD_IN_REGS keys live in registers; the rest live in dynamic shared memory, stored
// transposed so that thread c's i-th smem item sits at index c + i * BLOCK_THREADS. Keys are
// handled as raw integer bit patterns (Int_Key_Data_Type).
//
// The kernel walks the key bits from the most significant usable bit downwards, growing a bit
// mask. At each step it counts the keys that contain every bit of the mask and sums their values
// block-wide. When the accumulated mass meets the selection condition, the matching
// (key bits, index within segment) pairs are appended to a shared-memory list of at most MAX_TOP_K
// entries, in no particular sorted order.
//
// Outputs:
//   * gmem_active_count_per_segment[segment] and atomicMax(gmem_active_count_total) receive the
//     number of selected items, or the full segment length when the selected mass stayed below
//     top_p.
//   * If that count is below MAX_TOP_K, the selected pairs are written to gmem_dst_keys /
//     gmem_dst_vals at the segment's offset.
// Dynamic shared memory:
//   MAX_TOP_K * sizeof(int2) + ITEMS_PER_THREAD_IN_SMEM * BLOCK_THREADS * sizeof(Int_Key_Data_Type)
//   bytes.
// Assumptions (not checked here):
//   * keys are non-negative values in [0, 1] (see the bit-skipping note below);
//   * the segment length is a multiple of KEYS_PER_LDG, and the source pointer is aligned for
//     copy_t loads.
template<typename Kernel_params, int ITEMS_PER_THREAD>
__global__ __launch_bounds__(Kernel_params::BLOCK_THREADS,
                             1) void segmented_top_p_single_pass(TopKPerSegmentParams params)
{
#if DO_DEBUG_PRINT
    constexpr int debug_block_id = 26;
#endif

    using Key_Data_Type     = typename Kernel_params::Key_Data_Type;
    using Int_Key_Data_Type = typename Float_as_int_<Key_Data_Type>::Type;

    // Keys per lane per vectorized load (e.g. 4 fp16 keys or 2 fp32 keys; 1 for scalar loads).
    constexpr int                                         KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG;
    typedef Copy_t<Key_Data_Type, WARP_SZ * KEYS_PER_LDG> copy_t;
    // View one vectorized load either as the copy type or as its individual key bit patterns.
    union access_t {
        copy_t            v;
        Int_Key_Data_Type x[KEYS_PER_LDG];  // supported size 1,2,4
    };

    constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;

    // Split of each thread's items between registers and shared memory.
    constexpr int ITEMS_PER_THREAD_IN_REGS = ITEMS_PER_THREAD * (1.0F - SMEM_FRACTION);
    constexpr int ITEMS_PER_THREAD_IN_SMEM = ITEMS_PER_THREAD - ITEMS_PER_THREAD_IN_REGS;

#if DO_DEBUG_PRINT == 1
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("ITEMS_PER_THREAD, ITEMS_PER_THREAD_IN_REGS, ITEMS_PER_THREAD_IN_SMEM = %d, %d, %d\n",
               ITEMS_PER_THREAD,
               ITEMS_PER_THREAD_IN_REGS,
               ITEMS_PER_THREAD_IN_SMEM);
    }
#endif

    // MIN_KEY pads out-of-range items and marks items that were already selected.
    // ENABLED_PER_THREAD: number of 32-bit words needed for one "enabled" flag bit per item.
    constexpr int          MIN_KEY            = 0;
    constexpr int          ENABLED_PER_THREAD = (ITEMS_PER_THREAD + 32 - 1) / 32;
    // Dynamic smem layout: [MAX_TOP_K int2 selected pairs][transposed per-thread smem items].
    extern __shared__ int2 dynamic_smem[];
    int2*                  smem_selected_elements = dynamic_smem;
    Int_Key_Data_Type*     smem_thread_items = reinterpret_cast<Int_Key_Data_Type*>(smem_selected_elements + MAX_TOP_K);

    __shared__ unsigned int smem_selected_count;

    // Specialize BlockScan type for our thread block
    typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;

    // Specialize BlockReduce type for our thread block
    typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce;
    __shared__ float                               smem_p_sum_total;

    // Scan and reduce are never live at the same time, so they share temp storage.
    __shared__ union {
        typename BlockScan::TempStorage scan;

        typename BlockReduce::TempStorage reduce;
    } temp_storage;
    // Initialize running total (count of selected items carried across BlockScan calls)
    BlockPrefixCallbackOp prefix_op(0);

    // Only meaningful in thread 0, which is the only thread that reads or writes it.
    unsigned int old_selected_count;

    uint32_t segment = blockIdx.y * gridDim.x + blockIdx.x;

    // Preceding TopK has shortcutted this segment
    if (params.gmem_begin_offsets[segment] == params.gmem_end_offsets[segment]) {
        if (threadIdx.x == 0) {
            params.gmem_active_count_per_segment[segment] = 1;
            atomicMax(params.gmem_active_count_total, 1);
        }
        return;
    }

    // Keys are read and written as raw integer bit patterns.
    Int_Key_Data_Type* gmem_src_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_src_keys);
    Int_Key_Data_Type* gmem_dst_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_dst_keys);
    int32_t*           gmem_dst_vals = reinterpret_cast<int32_t*>(params.gmem_dst_vals);

    constexpr int BITS_IN_KEY = sizeof(Key_Data_Type) * 8;

    // All segments are assumed to have the same length (num_items / num_segments).
    int items       = params.num_items / params.num_segments;
    int first_index = segment * items;
    gmem_src_keys += first_index;
    gmem_dst_keys += first_index;
    gmem_dst_vals += first_index;

    int               index_limit                            = items;
    Int_Key_Data_Type thread_items[ITEMS_PER_THREAD_IN_REGS] = {0};

    // Load all keys into registers and smem
    const int     lane_id   = threadIdx.x % WARP_SZ;
    const int     warp_id   = threadIdx.x / WARP_SZ;
    constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SZ;

    access_t ZERO;
    for (int i = 0; i < KEYS_PER_LDG; i++) {
        ZERO.x[i] = MIN_KEY;
    }

    // registers: thread t holds items [t * ITEMS_PER_THREAD, t * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS)
    for (int iter = 0; iter < ITEMS_PER_THREAD_IN_REGS; iter++) {
        int offset         = (iter + threadIdx.x * ITEMS_PER_THREAD);
        thread_items[iter] = (offset < index_limit) ? gmem_src_keys[offset] : MIN_KEY;
    }

    // shared memory: each warp loads the smem part of thread c's items with vectorized loads
    // (KEYS_PER_LDG keys per lane) and scatters them transposed to c + i * BLOCK_THREADS.
    // The `offset < index_limit` test covers the whole vector only if items % KEYS_PER_LDG == 0.
    for (int c = warp_id; c < BLOCK_THREADS; c += NUM_WARPS) {
        for (int iter = lane_id * KEYS_PER_LDG; iter < ITEMS_PER_THREAD_IN_SMEM; iter += WARP_SZ * KEYS_PER_LDG) {
            int      offset = iter + c * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS;
            access_t val;
            val.v = (offset < index_limit) ? *reinterpret_cast<copy_t*>(&gmem_src_keys[offset]) : ZERO.v;
            for (int i = 0; i < KEYS_PER_LDG; i++) {
                smem_thread_items[c + (iter + i) * BLOCK_THREADS] = val.x[i];
            }
            // smem_thread_items[c  + iter * BLOCK_THREADS] = (offset < index_limit)? gmem_src_keys[offset] : MIN_KEY;
        }
    }

    // select_mask: bits committed so far; save_mask: last mask that over-selected.
    Int_Key_Data_Type select_mask = 0;
    Int_Key_Data_Type save_mask   = 0;

    // Int_Key_Data_Type save_bit = 0;
    // set to true when we finish with too few keys, so we go back to last_save_mask one more time
    // NOTE(review): is_last_iter is never assigned true in this kernel, so it never triggers.
    bool is_last_iter = false;

    if (threadIdx.x == 0) {
        smem_selected_count = 0;
        old_selected_count  = 0;
    }

    // iterate over bits.
    // skip the first two bits,
    // * bit 31 is the sign bit. all values are positive
    // * bit 30 is only set for values >= 2, but the input consists only of values in
    // the range of [0,1]
    // (for fp16 keys BITS_IN_KEY is 16, so the skipped bits are 15 and 14 instead)
    constexpr int               START_BIT = BITS_IN_KEY - 1;
    constexpr int               SKIP_BITS = 2;
    constexpr Int_Key_Data_Type ONE       = (Int_Key_Data_Type)1;
    uint32_t                    selected;
    uint32_t                    sc;
    float                       p_sum_total     = 0.0F;
    float                       old_p_sum_total = 0.0F;
    uint32_t                    offset          = 0;
    // `bit` is unsigned, so the loop has no natural end; it leaves via the `bit == 0` break below.
    for (Int_Key_Data_Type bit = START_BIT - SKIP_BITS; true; --bit) {
        __syncthreads();
        Int_Key_Data_Type bit_mask = select_mask | (ONE << bit);

        // One flag bit per item: set when the item contains every bit of bit_mask.
        uint32_t enabled[ENABLED_PER_THREAD] = {0};
        float    thread_sum                  = 0.0F;

        for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
            // check if all the bits from bit mask are contained in the thread_item. If yes, set respective
            // bit of enabled
            auto     val        = thread_items[item];
            uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
            // thread_sum += (is_enabled)? to_float(val) : 0.0F;
            thread_sum += is_enabled * to_float(val);
            enabled[item / 32] |= is_enabled << (item % 32);
        }

        for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
            int idx = threadIdx.x + item * BLOCK_THREADS;
            // int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
            auto     val        = smem_thread_items[idx];
            uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
            // thread_sum += (is_enabled)? to_float(val) : 0.0F;
            thread_sum += is_enabled * to_float(val);
            enabled[(ITEMS_PER_THREAD_IN_REGS + item) / 32] |= is_enabled << ((ITEMS_PER_THREAD_IN_REGS + item) % 32);
        }

        // Number of this thread's items matching the current mask.
        selected = 0;
#pragma unroll
        for (int i = 0; i < ENABLED_PER_THREAD; i++) {
            selected += __popc(enabled[i]);
        }

        // Block-wide mass of the matching keys (valid in thread 0 only).
        float p_sum = BlockReduce(temp_storage.reduce).Sum(thread_sum);

        // Thread 0 accumulates the total mass and broadcasts it through shared memory.
        if (threadIdx.x == 0) {
            p_sum_total += p_sum;
            smem_p_sum_total = p_sum_total;
        }

        __syncthreads();
        p_sum_total = smem_p_sum_total;
        __syncthreads();

        // Exclusive scan gives each thread its write slot in smem_selected_elements; prefix_op
        // carries the running count of selected items across iterations.
        BlockScan(temp_storage.scan).ExclusiveSum(selected, offset, prefix_op);

        if (threadIdx.x == 0) {
            smem_selected_count = prefix_op.running_total;
        }

        __syncthreads();
        sc = smem_selected_count;
        __syncthreads();

        // float p_diff = params.top_p - p_sum_total;
        float p_diff = p_sum_total - params.top_p;

        // Extract the matching items when the mass is within top_p + P_EPSILON, when it exceeds
        // top_p while the selection still fits in MAX_TOP_K, or when the last bit was reached.
        if ((p_sum_total <= params.top_p + P_EPSILON && p_sum_total > 0)
            || (p_sum_total > params.top_p && sc <= MAX_TOP_K) || (bit == 0 && p_sum_total > 0) || is_last_iter) {

#if DO_DEBUG_PRINT == 1
            __syncthreads();
            if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
                sc = smem_selected_count;
                printf("bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
                       bit,
                       bit_mask,
                       offset,
                       blockIdx.x,
                       threadIdx.x,
                       sc,
                       p_sum,
                       p_sum_total);
            }
            __syncthreads();
#endif

            for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
                // last condition should not trigger with well trained weights, but we will get
                // illegal memory access if we do not have one in those rare cases
                if (enabled[item / 32] & (ONE << (item % 32)) && offset < MAX_TOP_K) {
                    smem_selected_elements[offset] =
                        make_int2(thread_items[item], item + threadIdx.x * ITEMS_PER_THREAD);
                    ++offset;
                    // Clear the item so later iterations do not select it again.
                    thread_items[item] = MIN_KEY;
                }
            }

            for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
                if (enabled[(item + ITEMS_PER_THREAD_IN_REGS) / 32] & (ONE << ((item + ITEMS_PER_THREAD_IN_REGS) % 32))
                    && offset < MAX_TOP_K) {
                    int idx = threadIdx.x + item * BLOCK_THREADS;
                    // int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
                    // if (idx <  params.num_items_per_segment_in_smem)
                    {
                        smem_selected_elements[offset] = make_int2(
                            smem_thread_items[idx], item + threadIdx.x * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS);
                        ++offset;
                        smem_thread_items[idx] = MIN_KEY;
                    }
                }
            }
        }

#if DO_DEBUG_PRINT == 1
        if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
            printf("!!!! bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
                   bit,
                   bit_mask,
                   offset,
                   blockIdx.x,
                   threadIdx.x,
                   sc,
                   p_sum,
                   p_sum_total);
        }
#endif

        // Stop when the mass is within [top_p, top_p + P_EPSILON], when it exceeds top_p with a
        // selection that fits in MAX_TOP_K, or when all usable bits have been visited.
        if (p_diff <= P_EPSILON && p_diff >= 0 || (p_sum_total > params.top_p && sc <= MAX_TOP_K) || bit == 0) {

            break;
        }
        // p > top_p
        else if (p_diff > P_EPSILON) {
            // There are too many bits in the current selection
            // Save the current state and go to the next bit
            // If there are not enough items left using the next bit
            // it's necessary to restart here with the current bit not set
            save_mask = bit_mask;
            select_mask |= bit_mask;

            // Roll back this iteration's count and mass (thread 0 owns both).
            if (threadIdx.x == 0) {
                smem_selected_count = old_selected_count;
                p_sum_total         = old_p_sum_total;

                prefix_op.running_total = old_selected_count;
            }
        }
        else {
            // sc < num_top_k branch
            if (save_mask) {
                select_mask = save_mask;

                save_mask = 0;
            }
            // Commit this iteration's count and mass as the new rollback point.
            if (threadIdx.x == 0) {
                old_selected_count = smem_selected_count;
                old_p_sum_total    = p_sum_total;
            }
        }
    }

    __syncthreads();

    // store data to global memory
    // If the selected mass never reached top_p, report the full segment length instead.
    sc = (p_sum_total < params.top_p) ? params.num_items / params.num_segments : smem_selected_count;
    if (threadIdx.x == 0) {
        params.gmem_active_count_per_segment[segment] = sc;
        atomicMax(params.gmem_active_count_total, sc);
    }
    // The shared-memory list holds at most MAX_TOP_K pairs; larger counts are not written out.
    if (sc >= MAX_TOP_K) {
        return;
    }
    for (int i = threadIdx.x; i < sc; i += blockDim.x) {
        int2 selected_element = smem_selected_elements[i];
        gmem_dst_keys[i]      = selected_element.x;
        gmem_dst_vals[i]      = selected_element.y;
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Computes the dynamic shared memory, in bytes, that segmented_top_p_single_pass<Kernel_params, ...>
/// needs for this problem. Returns -1 when the single-pass kernel cannot be used.
///
/// Size = MAX_TOP_K int2 slots for the selected (key, index) pairs, plus the shared-memory share
/// of every thread's keys (stored as raw integer bit patterns). ITEMS_PER_THREAD is derived from
/// the segment length exactly as the dispatcher and the kernel do.
///
/// Returns -1 when:
///   * num_segments <= 0, or a segment is empty;
///   * the segment needs a kernel index above MAX_KERNEL_INDEX (no such instantiation exists);
///   * dynamic + static shared memory exceeds context.sm_shared_size;
///   * top_p + P_EPSILON > 1;
///   * the kernel's vectorized loads of Kernel_params::KEYS_PER_LDG keys would be misaligned or
///     would run past the end of the segment.
template<typename Kernel_params>
int getSmemSizeAndCheck(const TopKPerSegmentContext& context, const TopKPerSegmentParams& params)
{
    constexpr int BLOCK_THREADS   = Kernel_params::BLOCK_THREADS;
    constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;
    // Width of the kernel's vectorized key loads; every alignment check below must use it.
    constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG;
    // Highest kernel index instantiated by segmentedTopPSinglePass_dispatch (switch cases 0..7).
    constexpr int MAX_KERNEL_INDEX = 7;
    using Key_Data_Type     = typename Kernel_params::Key_Data_Type;
    using Int_Key_Data_Type = typename Float_as_int_<Key_Data_Type>::Type;

    if (params.num_segments <= 0) {
        return -1;
    }
    const int num_items_per_segment = params.num_items / params.num_segments;
    const int kernel_index          = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
    if (kernel_index < 0 || kernel_index > MAX_KERNEL_INDEX) {
        return -1;
    }

    // Same register/shared-memory split as computed inside the kernel.
    const int items_per_thread         = (kernel_index + 1) * ITEMS_INCREMENT;
    const int items_per_thread_in_regs = items_per_thread * (1.0F - SMEM_FRACTION);
    const int items_per_thread_in_smem = items_per_thread - items_per_thread_in_regs;

    const int smem_size =
        MAX_TOP_K * sizeof(int2) + items_per_thread_in_smem * BLOCK_THREADS * sizeof(Int_Key_Data_Type);

    // The BLOCK_THREADS * sizeof(float) term budgets for the kernel's static shared memory.
    if (smem_size + BLOCK_THREADS * sizeof(float) > (size_t)context.sm_shared_size
        || params.top_p + P_EPSILON > 1.0F || items_per_thread_in_regs % KEYS_PER_LDG != 0
        || items_per_thread_in_smem % KEYS_PER_LDG != 0 || num_items_per_segment % KEYS_PER_LDG != 0) {
        return -1;
    }

    return smem_size;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Runtime-dtype front end for getSmemSizeAndCheck<Kernel_params>. Returns the dynamic
/// shared-memory requirement (or -1) of the parameter bundle chosen from the score type and the
/// segment length:
///   * kFLOAT: kernel_params_float if the segment length is even, else kernel_params_float_1.
///   * any other DT_SCORE is treated as fp16: kernel_params_half if the length is divisible
///     by 4, else kernel_params_half_1.
int getSmemSizeAndCheck(const TopKPerSegmentContext& context,
                        const TopKPerSegmentParams&  params,
                        const DType_t                DT_SCORE)
{
    const int  segment_len = params.num_items / params.num_segments;
    const bool is_fp32     = (DT_SCORE == kFLOAT);

    if (is_fp32) {
        return (segment_len % 2 == 0) ? getSmemSizeAndCheck<kernel_params_float>(context, params) :
                                        getSmemSizeAndCheck<kernel_params_float_1>(context, params);
    }
    return (segment_len % 4 == 0) ? getSmemSizeAndCheck<kernel_params_half>(context, params) :
                                    getSmemSizeAndCheck<kernel_params_half_1>(context, params);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Launches segmented_top_p_single_pass on `stream`: one CTA per segment
/// (grid = num_segments x 1) and Kernel_params::BLOCK_THREADS threads per CTA.
///
/// The ITEMS_PER_THREAD template argument is chosen at runtime. With
///   kernel_index = ceil(items_per_segment / (BLOCK_THREADS * ITEMS_INCREMENT)) - 1,
/// ITEMS_PER_THREAD = ITEMS_INCREMENT * (kernel_index + 1). Only kernel_index 0..7 is
/// instantiated; any other value terminates the process.
///
/// Shared memory and errors:
///   * The dynamic shared-memory size comes from getSmemSizeAndCheck<Kernel_params>.
///   * The launch is checked with cudaGetLastError, so a rejected configuration (-1 from the
///     size check) or any other launch failure is reported rather than silently ignored.
template<typename Kernel_params>
void segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams&  params,
                                      const TopKPerSegmentContext& context,
                                      cudaStream_t                 stream)
{
    constexpr int BLOCK_THREADS   = Kernel_params::BLOCK_THREADS;
    constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;

    const int num_items_per_segment = params.num_items / params.num_segments;
    const int kernel_index          = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;

// Launch one ITEMS_PER_THREAD instantiation. Opt in to >48 KB dynamic shared memory first when
// a size is requested.
#define KERNEL_RUN(INDEX)                                                                                              \
    {                                                                                                                  \
        if (smem_size > 0)                                                                                             \
            check_cuda_error(                                                                                          \
                cudaFuncSetAttribute(segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)>,          \
                                     cudaFuncAttributeMaxDynamicSharedMemorySize,                                      \
                                     smem_size));                                                                      \
        segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)>                                        \
            <<<grid_dim, Kernel_params::BLOCK_THREADS, smem_size, stream>>>(params);                                   \
    }

    const int smem_size = getSmemSizeAndCheck<Kernel_params>(context, params);

    const dim3 grid_dim(params.num_segments, 1);

    switch (kernel_index) {
        case 0:
            KERNEL_RUN(0) break;
        case 1:
            KERNEL_RUN(1) break;
        case 2:
            KERNEL_RUN(2) break;
        case 3:
            KERNEL_RUN(3) break;
        case 4:
            KERNEL_RUN(4) break;
        case 5:
            KERNEL_RUN(5) break;
        case 6:
            KERNEL_RUN(6) break;
        case 7:
            KERNEL_RUN(7) break;
        default:
            exit(1);
    }

    // Kernel launches do not return errors directly. Surface bad launch configurations here,
    // e.g. smem_size == -1 becomes SIZE_MAX bytes of dynamic shared memory.
    check_cuda_error(cudaGetLastError());
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// Sorts every segment of (gmem_src_keys, gmem_src_vals) by descending key into (gmem_dst_keys, gmem_dst_vals)
// for one Kernel_params configuration.
//
// Two-phase usage:
//   * temp_storage == nullptr: size query only. temp_storage_bytes is set to cub's radix-sort requirement
//     (rounded up to 256 bytes) plus a 256-byte slot for gmem_active_count_total and a 256-byte-aligned array of
//     num_segments ints for gmem_active_count_per_segment; nothing is launched.
//   * otherwise: temp_storage_bytes must be the queried value (the extra regions are carved from its tail).
//     segmentedTopPSinglePass_dispatch runs first; the int in gmem_active_count_total (max_num_items) is then
//     copied to the host, which blocks on `stream`. If it is 0 or >= MAX_TOP_K, the full segments are
//     radix-sorted with cub; otherwise blockSort runs in place on the dst buffers using
//     gmem_active_count_per_segment and max_num_items.
// All CUDA runtime and cub calls are checked with check_cuda_error instead of being silently ignored.
template<typename Kernel_params>
void topPPerSegment_dispatch(const TopKPerSegmentContext& context,
                             TopKPerSegmentParams&        params,
                             void*                        temp_storage,
                             size_t&                      temp_storage_bytes,
                             cudaStream_t                 stream)
{
    using Key_Data_Type   = typename Kernel_params::Key_Data_Type;
    using Value_Data_Type = typename Kernel_params::Value_Data_Type;

    // Descending key/value radix sort over all segments. With storage == nullptr cub only writes the
    // required temp-storage size into `bytes`.
    auto radix_sort = [&](void* storage, size_t& bytes) -> cudaError_t {
        if (params.num_segments > 1) {
            return cub::DeviceSegmentedRadixSort::SortPairsDescending(
                storage,
                bytes,
                reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
                reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
                reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
                reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
                params.num_items,
                params.num_segments,
                params.gmem_begin_offsets,
                params.gmem_end_offsets,
                0,
                sizeof(Key_Data_Type) * 8,
                stream);
        }
        return cub::DeviceRadixSort::SortPairsDescending(storage,
                                                         bytes,
                                                         reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
                                                         reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
                                                         reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
                                                         reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
                                                         params.num_items,
                                                         0,
                                                         sizeof(Key_Data_Type) * 8,
                                                         stream);
    };

    if (temp_storage == nullptr) {
        check_cuda_error(radix_sort(temp_storage, temp_storage_bytes));
        temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
        // total active counts
        temp_storage_bytes += div_up(sizeof(int), 256) * 256;
        // storage for gmem_end_offsets
        temp_storage_bytes += div_up(sizeof(int) * params.num_segments, 256) * 256;
        return;
    }

    // Undo the query's padding to recover cub's share; the two count regions follow it.
    size_t cub_temp_storage_bytes =
        temp_storage_bytes - div_up(sizeof(int), 256) * 256 - div_up(sizeof(int) * params.num_segments, 256) * 256;
    void* cub_temp_storage         = temp_storage;
    params.gmem_active_count_total = reinterpret_cast<int*>((char*)temp_storage + cub_temp_storage_bytes);
    params.gmem_active_count_per_segment =
        reinterpret_cast<int*>((char*)params.gmem_active_count_total + div_up(sizeof(int), 256) * 256);

    int num_items_per_segment = params.num_items / params.num_segments;

    check_cuda_error(cudaMemsetAsync(params.gmem_active_count_total, 0, sizeof(int), stream));
    check_cuda_error(cudaMemsetAsync(params.gmem_dst_keys, 0, params.num_items * sizeof(Key_Data_Type), stream));
    segmentedTopPSinglePass_dispatch<Kernel_params>(params, context, stream);

    // The host needs this count to choose the follow-up sort, hence the synchronization.
    int max_num_items = 0;
    check_cuda_error(
        cudaMemcpyAsync(&max_num_items, params.gmem_active_count_total, sizeof(int), cudaMemcpyDeviceToHost, stream));
    check_cuda_error(cudaStreamSynchronize(stream));

    if (max_num_items >= MAX_TOP_K || max_num_items == 0) {
        check_cuda_error(radix_sort(cub_temp_storage, cub_temp_storage_bytes));
    }
    else {
        // run at max supported value
        blockSort<Key_Data_Type>((const Key_Data_Type*)(params.gmem_dst_keys),
                                 (Key_Data_Type*)(params.gmem_dst_keys),
                                 (const Value_Data_Type*)(params.gmem_dst_vals),
                                 (Value_Data_Type*)(params.gmem_dst_vals),
                                 params.gmem_active_count_per_segment,
                                 max_num_items,
                                 num_items_per_segment,
                                 params.num_segments,
                                 stream);
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

// Chooses the Kernel_params instantiation for the score type and per-segment size, then forwards to
// topPPerSegment_dispatch (which either answers a size query or runs the sort).
//   float scores: even per-segment length -> kernel_params_float,  otherwise kernel_params_float_1
//   other scores: per-segment length % 4 == 0 -> kernel_params_half, otherwise kernel_params_half_1
// Always returns 0.
int topPPerSegment(const TopKPerSegmentContext& context,
                   TopKPerSegmentParams&        params,
                   const DType_t                DT_SCORE,
                   void*                        temp_storage,
                   size_t&                      temp_storage_bytes,
                   cudaStream_t                 stream)
{
    const int  items_per_segment = params.num_items / params.num_segments;
    const bool float_scores      = (DT_SCORE == kFLOAT);

    if (float_scores && items_per_segment % 2 == 0) {
        topPPerSegment_dispatch<kernel_params_float>(context, params, temp_storage, temp_storage_bytes, stream);
    }
    else if (float_scores) {
        topPPerSegment_dispatch<kernel_params_float_1>(context, params, temp_storage, temp_storage_bytes, stream);
    }
    else if (items_per_segment % 4 == 0) {
        topPPerSegment_dispatch<kernel_params_half>(context, params, temp_storage, temp_storage_bytes, stream);
    }
    else {
        topPPerSegment_dispatch<kernel_params_half_1>(context, params, temp_storage, temp_storage_bytes, stream);
    }
    return 0;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace segmented_topp_impl

// Initializes the buffers used by the segmented top-p sort.
//   * Block 0 writes the batch_size + 1 segment boundaries seg * n into both topp_offset_buf and
//     begin_topp_offset_buf_.
//   * All threads fill topp_id_val_buf[idx] = idx % n for idx < batch_size * n with a grid-stride loop,
//     so any grid/block size covers the whole buffer.
__global__ void topPInitialize(
    int* topp_id_val_buf, int* topp_offset_buf, int* begin_topp_offset_buf_, const int batch_size, const int n)
{
    if (blockIdx.x == 0) {
        for (int seg = threadIdx.x; seg <= batch_size; seg += blockDim.x) {
            const int boundary          = seg * n;
            topp_offset_buf[seg]        = boundary;
            begin_topp_offset_buf_[seg] = boundary;
        }
    }

    const int total  = batch_size * n;
    const int stride = blockDim.x * gridDim.x;
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total; idx += stride) {
        topp_id_val_buf[idx] = idx % n;
    }
}

// Launches topPInitialize on `stream` with a fixed 32 x 512 configuration; the kernel's strided loops
// make the result independent of that launch size.
// n: the column count of the logits buffer used for top-p sampling.
void invokeTopPInitialize(int*         topp_id_val_buf,
                          int*         topp_offset_buf,
                          int*         begin_topp_offset_buf_,
                          const size_t batch_size,
                          const int    n,
                          cudaStream_t stream)
{
    constexpr int grid_size  = 32;
    constexpr int block_size = 512;
    // The kernel takes batch_size as int.
    topPInitialize<<<grid_size, block_size, 0, stream>>>(
        topp_id_val_buf, topp_offset_buf, begin_topp_offset_buf_, static_cast<int>(batch_size), n);
}

// Per-row shortcut for top-p sampling: one block of THREADBLOCK_SIZE threads per row (row = blockIdx.x).
//
// The block reduces the row's vocab_size entries of `log_probs` to the MAX_K largest ones. Thread 0 then
// sets begin_offset_buf[row] = offset_buf[row]. If the sum of those MAX_K values reaches the row's threshold
// (top_ps[row] when top_ps is non-null, else top_p), it also:
//   * advances begin_offset_buf[row] by vocab_size, so the row's sort segment becomes empty;
//   * stores the MAX_K candidates at the start of the row in topk_tmp_id_buf (token id = flat index % vocab_size)
//     and topk_tmp_val_buf.
// Rows with skip_decode[row] set return immediately and are left untouched.
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T*     log_probs,  // prob.
                                                                          int*         topk_tmp_id_buf,
                                                                          T*           topk_tmp_val_buf,
                                                                          const int    vocab_size,
                                                                          int*         offset_buf,
                                                                          int*         begin_offset_buf,
                                                                          const float  top_p,
                                                                          const float* top_ps,
                                                                          const bool*  skip_decode)
{
    const int row = blockIdx.x;
    if (skip_decode != nullptr && skip_decode[row]) {
        return;
    }
    const float row_threshold = (top_ps == nullptr) ? top_p : top_ps[row];

    using Candidates  = TopK<T, MAX_K>;
    using BlockReduce = cub::BlockReduce<Candidates, THREADBLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage reduce_storage;

    // Empty slots: id -1 and value -HALF_FLT_MAX (half) / -FLT_MAX (otherwise).
    const T    max_val = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
    Candidates local;
#pragma unroll
    for (int k = 0; k < MAX_K; ++k) {
        local.p[k] = -1;
        local.u[k] = -max_val;
    }

    // Each thread visits every THREADBLOCK_SIZE-th entry of the row; candidates carry their flat index.
    const int row_start = row * vocab_size;
    for (int col = threadIdx.x; col < vocab_size; col += THREADBLOCK_SIZE) {
        local.insert(log_probs[row_start + col], row_start + col);
    }

    Candidates best = BlockReduce(reduce_storage).Reduce(local, reduce_topk_op<T, MAX_K>);

    // cub's block reduction result is only valid in thread 0.
    if (threadIdx.x != 0) {
        return;
    }

    begin_offset_buf[row] = offset_buf[row];

    T mass = (T)(0.0f);
#pragma unroll
    for (int k = 0; k < MAX_K; ++k) {
        mass += best.u[k];
    }

    if ((float)mass >= row_threshold) {
        // The top candidates alone cover the threshold: skip this row's sort and publish them directly.
        begin_offset_buf[row] += vocab_size;
#pragma unroll
        for (int k = 0; k < MAX_K; ++k) {
            topk_tmp_id_buf[row_start + k]  = best.p[k] % vocab_size;
            topk_tmp_val_buf[row_start + k] = best.u[k];
        }
    }
}

// Stateful block-prefix callback for cub::BlockScan. Each call returns the running total accumulated so far
// (the seed for the current tile's scan) and then adds the current tile's block aggregate to it.
struct BlockPrefixCallbackOp {
    // Sum of all block aggregates passed in so far.
    float running_total;

    __device__ BlockPrefixCallbackOp(float initial_total): running_total(initial_total) {}

    // Invoked by the first warp of the block; the value returned by thread 0 seeds the block-wide scan.
    __device__ float operator()(float block_aggregate)
    {
        const float prefix = running_total;
        running_total      = prefix + block_aggregate;
        return prefix;
    }
};

// Samples one token per row from probabilities sorted in descending order.
//
// Launch: one block per row (row b = blockIdx.x), BLOCK_SIZE threads (a multiple of 32), no dynamic shared memory.
// Row b's sorted entries start at b * vocab_size in sorted_log_probs / sorted_id_vals; the values are treated
// as probabilities (they are prefix-summed and passed to logf).
//
// A threshold r = curand_uniform(curandstate + b) * p is drawn, where p = top_ps[b] if top_ps is non-null and
// top_p otherwise. The first sorted position whose inclusive prefix sum reaches r is selected.
// Special cases:
//   * begin_offset_buf[b] == offset_buf[b]: the row was resolved by topp_beam_topk_kernel, so position 0 is taken.
//   * no position reaches r (the accumulated mass can end up just below r through floating-point rounding):
//     position 0, the most probable token, is taken.
// In every case the chosen position also updates the optional outputs:
//   * cum_log_probs[b] is increased by log(prob);
//   * output_log_probs[b] is set to log(prob);
//   * when both sequence_length and finished_buf are non-null, sequence_length[b] is incremented unless the row
//     was already finished, and finished_buf[b] is set to (ids[b] == end_ids[b]).
// Rows with skip_decode[b] set are left untouched.
template<typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T*             sorted_log_probs,
                              int*           sorted_id_vals,
                              int*           ids,
                              int*           sequence_length,
                              bool*          finished_buf,
                              float*         cum_log_probs,
                              float*         output_log_probs,
                              const int*     begin_offset_buf,
                              const int*     offset_buf,
                              const int      vocab_size,
                              curandState_t* curandstate,
                              const float    top_p,
                              const float*   top_ps,
                              const int*     end_ids,
                              const int      batch_size,
                              const bool*    skip_decode)
{
    __shared__ int   stop_shared;
    __shared__ float rand_num_s;

    const int tid      = threadIdx.x;
    const int batch_id = blockIdx.x;
    if (skip_decode != nullptr && skip_decode[batch_id]) {
        return;
    }

    constexpr int WARP_SIZE      = 32;
    constexpr int NUM_WARPS      = BLOCK_SIZE / WARP_SIZE;
    const int     lane_id        = threadIdx.x % WARP_SIZE;
    const int     warp_id        = threadIdx.x / WARP_SIZE;
    const float   prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
    const int     offset         = batch_id * vocab_size;

    // Emits the token at sorted position `pos` of this row and updates the optional bookkeeping buffers.
    // Exactly one thread of the block calls this per row.
    auto emit_sample = [&](int pos) {
        ids[batch_id] = sorted_id_vals[offset + pos];
        if (cum_log_probs != nullptr || output_log_probs != nullptr) {
            float lprob = logf(sorted_log_probs[offset + pos]);
            if (cum_log_probs != nullptr) {
                cum_log_probs[batch_id] += lprob;
            }
            if (output_log_probs != nullptr) {
                output_log_probs[batch_id] = lprob;
            }
        }
        if (sequence_length != nullptr && finished_buf != nullptr) {
            sequence_length[batch_id] =
                finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
            finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0;
        }
    };

    if (threadIdx.x == 0) {
        stop_shared = 0;
        rand_num_s  = curand_uniform(curandstate + blockIdx.x) * prob_threshold;
    }

    // if begin_offset_buf and offset_buf of sorting have same value,
    // this means that we have find best one in beam_topK_kernel_for_topP
    // and skip the sorting. So, we can skip then during sampling.
    if (begin_offset_buf[batch_id] == offset_buf[batch_id]) {
        if (tid == 0) {
            emit_sample(0);
        }
        return;
    }

    typedef cub::BlockScan<float, BLOCK_SIZE>  BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;
    __shared__ uint32_t                        selected_shared[NUM_WARPS];
    // Initialize running total
    BlockPrefixCallbackOp prefix_op(0);

    if (lane_id == 0) {
        selected_shared[warp_id] = 0;
    }

    __syncthreads();

    int   end           = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
    int   i_active      = 0;
    float thread_offset = 0;
    for (int i = tid; i < end; i += BLOCK_SIZE) {
        float thread_count = (i < vocab_size) ? (float)sorted_log_probs[offset + i] : 0.f;
        BlockScan(temp_storage).InclusiveSum(thread_count, thread_offset, prefix_op);

        // Prefix sums are non-decreasing across lanes, so the set bits form a suffix of the warp.
        uint32_t active_mask = __ballot_sync(0xFFFFFFFF, rand_num_s <= thread_offset);

        i_active = i;
        if (active_mask != 0) {
            if (lane_id == 0) {
                atomicAdd(&stop_shared, 1);
                selected_shared[warp_id] = active_mask;
            }
        }
        __syncthreads();
        if (stop_shared > 0) {
            break;
        }
    }

    // select first active warp
    bool skip = (selected_shared[warp_id] > 0) ? false : true;
    for (int i = 0; i < warp_id; i++) {
        if (selected_shared[i] != 0) {
            skip = true;
        }
    }
    if (!skip) {
        // First set lane of the suffix mask = first position whose prefix sum reaches rand_num_s.
        int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
        if (lane_id == active_lane_id) {
            emit_sample(i_active);
        }
    }

    // Nothing reached rand_num_s. stop_shared is block-uniform here: its last write precedes the loop's
    // final __syncthreads(). Fall back to the most probable token and keep the bookkeeping consistent with it.
    if (stop_shared == 0 && tid == 0) {
        emit_sample(0);
    }
}

// Top-p (nucleus) sampling for batch_size rows of vocab_size_padded entries each.
//
// Two-phase usage, selected by `workspace`:
//   * workspace == nullptr: size query only. Sets cub_temp_storage_size (256-byte aligned) and
//     workspace_size = cub temp storage + sorted-probability buffer + sorted-id buffer, then returns without
//     launching any kernel.
//   * workspace != nullptr: runs the sampling. cub_temp_storage_size must be the value produced by the query,
//     because it determines where the sorted buffers start inside `workspace`.
//
// Workspace layout: [cub temp storage | sorted_log_probs (batch_size*vocab_size T) | sorted_id_vals
// (batch_size*vocab_size int)], with each sorted buffer rounded up to a multiple of 256 bytes.
//
// Default (radix-sort) run path:
//   1. topp_beam_topk_kernel<T, 1>: for rows whose single largest entry already reaches the row's threshold,
//      writes it to the sorted buffers and advances begin_offset_buf[b] so that the row's segment is empty.
//   2. cub segmented descending radix sort of (log_probs, id_vals) over [begin_offset_buf[b], offset_buf[b + 1]).
//   3. topp_sampling picks one token per row.
// The segmented single-pass implementation is used instead only if ENABLE_SINGLE_PASS_TOP_P != 0,
// max_top_p < SINGLE_PASS_THRESHOLD and its shared-memory check succeeds. cuda_device_prop is required
// (FT_CHECK) only when that path is considered.
//
// Parameters (assumptions marked):
//   log_probs        despite the name, the kernels treat the values as probabilities (summed, then logf'd).
//   id_vals          token ids paired with log_probs; presumably filled by invokeTopPInitialize — confirm.
//   offset_buf       segment boundaries; offset_buf + 1 is used as the per-row end offsets, so it is expected
//                    to hold batch_size + 1 entries.
//   begin_offset_buf per-row segment begins; rewritten by topp_beam_topk_kernel on every run.
//   max_top_p/top_ps top_ps[b] is the per-row threshold when top_ps is non-null; otherwise max_top_p is used.
//   skip_decode      optional per-row flags; both kernels return early for flagged rows.
// Outputs: output_ids[b]; optionally cum_log_probs, output_log_probs, sequence_length and finished_buf
// (see topp_sampling).
template<typename T>
void invokeBatchTopPSampling(void*           workspace,
                             size_t&         workspace_size,
                             size_t&         cub_temp_storage_size,
                             int*            output_ids,
                             int*            sequence_length,
                             bool*           finished_buf,
                             float*          cum_log_probs,
                             float*          output_log_probs,
                             const T*        log_probs,
                             const int*      id_vals,
                             int*            offset_buf,
                             int*            begin_offset_buf,
                             curandState_t*  curandstate,
                             const int       batch_size,
                             const size_t    vocab_size_padded,
                             const int*      end_ids,
                             const float     max_top_p,
                             const float*    top_ps,
                             cudaStream_t    stream,
                             cudaDeviceProp* cuda_device_prop,
                             const bool*     skip_decode)
{
    // Here, we put batch size as an argument because the batch size of initialization
    // and inference may be different due to pipeline parallelism.
    // NOTE(review): vocab_size_padded is narrowed to int here, and batch_size * vocab_size below is int arithmetic.
    const int vocab_size = vocab_size_padded;
    const int block_size = 256;

    // Sizes of the two sorted output buffers, each rounded up to 256 bytes.
    size_t sorted_log_prob_buf_size = batch_size * vocab_size * sizeof(T);    // type T
    size_t sorted_id_vals_buf_size  = batch_size * vocab_size * sizeof(int);  // type int
    sorted_log_prob_buf_size        = div_up(sorted_log_prob_buf_size, 256) * 256;
    sorted_id_vals_buf_size         = div_up(sorted_id_vals_buf_size, 256) * 256;

    // Carve the workspace: cub temp storage first, then the sorted buffers. In query mode workspace is
    // nullptr and these pointers are only passed to the size queries.
    void* cub_temp_storage = workspace;
    T*    sorted_log_probs = (T*)((char*)cub_temp_storage + cub_temp_storage_size);
    int*  sorted_id_vals   = (int*)((char*)sorted_log_probs + sorted_log_prob_buf_size);

    // With ENABLE_SINGLE_PASS_TOP_P == 0 the radix-sort path is always taken.
    bool do_radix_sort = (ENABLE_SINGLE_PASS_TOP_P == 0 || max_top_p >= SINGLE_PASS_THRESHOLD);
    int  smem_size     = -1;

    segmented_topp_impl::TopKPerSegmentContext context;
    segmented_topp_impl::TopKPerSegmentParams  params;
    segmented_topp_impl::DType_t               dataTypeKind =
        (std::is_same<T, float>::value) ? segmented_topp_impl::kFLOAT : segmented_topp_impl::kHALF;

    if (!do_radix_sort) {
        // Configure the single-pass segmented implementation; a negative shared-memory size from the
        // check means it cannot run here, so fall back to the radix sort.
        FT_CHECK(cuda_device_prop != nullptr);
        memset(&context, 0, sizeof(context));
        context.sm_count       = cuda_device_prop->multiProcessorCount;
        context.sm_shared_size = cuda_device_prop->sharedMemPerMultiprocessor;
        context.sm_version     = cuda_device_prop->major * 100 + cuda_device_prop->minor * 10;

        memset(&params, 0, sizeof(params));
        params.gmem_src_keys        = reinterpret_cast<void*>(const_cast<T*>(log_probs));
        params.gmem_dst_keys        = sorted_log_probs;
        params.gmem_src_vals        = reinterpret_cast<void*>(const_cast<int*>(id_vals));
        params.gmem_dst_vals        = reinterpret_cast<void*>(sorted_id_vals);
        params.gmem_begin_offsets   = begin_offset_buf;
        params.gmem_end_offsets     = offset_buf + 1;
        params.workspace            = nullptr;
        params.num_items            = vocab_size * batch_size;
        params.num_segments         = batch_size;
        params.top_p                = max_top_p;
        params.confidence_threshold = 0.0F;

        smem_size     = getSmemSizeAndCheck(context, params, dataTypeKind);
        do_radix_sort = smem_size < 0;
    }

    if (do_radix_sort) {
        if (workspace == nullptr) {
            // Size query: only cub_temp_storage_size is written; the null key/value outputs are never touched.
            check_cuda_error(
                cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
                                                                   cub_temp_storage_size,
                                                                   log_probs,
                                                                   (T*)nullptr,
                                                                   id_vals,
                                                                   (int*)nullptr,
                                                                   vocab_size * batch_size,
                                                                   batch_size,
                                                                   begin_offset_buf,
                                                                   offset_buf + 1,
                                                                   0,              // begin_bit
                                                                   sizeof(T) * 8,  // end_bit = sizeof(KeyT) * 8
                                                                   stream));       // cudaStream_t
            cub_temp_storage_size = div_up(cub_temp_storage_size, 256) * 256;
            workspace_size        = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
            return;
        }

        // Rows whose top-1 probability already reaches the threshold get their result written directly and
        // an empty sort segment (begin_offset_buf advanced).
        topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
                                                                                       sorted_id_vals,
                                                                                       sorted_log_probs,
                                                                                       vocab_size,
                                                                                       offset_buf,
                                                                                       begin_offset_buf,
                                                                                       max_top_p,
                                                                                       top_ps,
                                                                                       skip_decode);

        // Sort each row's remaining segment [begin_offset_buf[b], offset_buf[b + 1]) in descending order.
        check_cuda_error(
            cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage,
                                                               cub_temp_storage_size,
                                                               log_probs,
                                                               sorted_log_probs,
                                                               id_vals,
                                                               sorted_id_vals,
                                                               vocab_size * batch_size,
                                                               batch_size,
                                                               begin_offset_buf,
                                                               offset_buf + 1,
                                                               0,              // begin_bit
                                                               sizeof(T) * 8,  // end_bit = sizeof(KeyT) * 8
                                                               stream));       // cudaStream_t
    }
    else {
        if (workspace == nullptr) {
            // Size query for the single-pass path (cub storage plus its extra count buffers).
            segmented_topp_impl::topPPerSegment(
                context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
            workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
            return;
        }
        else {
            topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
                                                                                           sorted_id_vals,
                                                                                           sorted_log_probs,
                                                                                           vocab_size,
                                                                                           offset_buf,
                                                                                           begin_offset_buf,
                                                                                           max_top_p,
                                                                                           top_ps,
                                                                                           skip_decode);
            segmented_topp_impl::topPPerSegment(
                context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
        }
    }

    // Draw one token per row. offset_buf + 1 is passed as the end offsets so that rows resolved by
    // topp_beam_topk_kernel (begin == end) are recognized.
    constexpr int SAMPLING_BLOCK_SIZE = 256;
    dim3          grid(batch_size);
    topp_sampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sorted_log_probs,
                                                                                    sorted_id_vals,
                                                                                    output_ids,
                                                                                    sequence_length,
                                                                                    finished_buf,
                                                                                    cum_log_probs,
                                                                                    output_log_probs,
                                                                                    begin_offset_buf,
                                                                                    offset_buf + 1,
                                                                                    vocab_size,
                                                                                    curandstate,
                                                                                    max_top_p,
                                                                                    top_ps,
                                                                                    end_ids,
                                                                                    batch_size,
                                                                                    skip_decode);
}

// Explicit instantiation of the batched top-p sampler for fp32 probabilities.
template void invokeBatchTopPSampling<float>(void*           workspace,
                                             size_t&         workspace_size,
                                             size_t&         cub_temp_storage_size,
                                             int*            output_ids,
                                             int*            sequence_length,
                                             bool*           finished_buf,
                                             float*          cum_log_probs,
                                             float*          output_log_probs,
                                             const float*    log_probs,
                                             const int*      id_vals,
                                             int*            offset_buf,
                                             int*            begin_offset_buf,
                                             curandState_t*  curandstate,
                                             const int       batch_size,
                                             const size_t    vocab_size_padded,
                                             const int*      end_ids,
                                             const float     max_top_p,
                                             const float*    top_ps,
                                             cudaStream_t    stream,
                                             cudaDeviceProp* cuda_device_prop,
                                             const bool*     skip_decode);

// Explicit instantiation of the batched top-p sampler for fp16 probabilities.
template void invokeBatchTopPSampling<half>(void*           workspace,
                                            size_t&         workspace_size,
                                            size_t&         cub_temp_storage_size,
                                            int*            output_ids,
                                            int*            sequence_length,
                                            bool*           finished_buf,
                                            float*          cum_log_probs,
                                            float*          output_log_probs,
                                            const half*     log_probs,
                                            const int*      id_vals,
                                            int*            offset_buf,
                                            int*            begin_offset_buf,
                                            curandState_t*  curandstate,
                                            const int       batch_size,
                                            const size_t    vocab_size_padded,
                                            const int*      end_ids,
                                            const float     max_top_p,
                                            const float*    top_ps,
                                            cudaStream_t    stream,
                                            cudaDeviceProp* cuda_device_prop,
                                            const bool*     skip_decode);

// Top-p sampling with one threshold shared by the whole batch. Forwards to invokeBatchTopPSampling with no
// per-row top_ps array, so `top_p` serves both as the maximum and as every row's threshold. The same two-phase
// protocol applies: call with workspace == nullptr to obtain workspace_size / cub_temp_storage_size first.
template<typename T>
void invokeTopPSampling(void*           workspace,
                        size_t&         workspace_size,
                        size_t&         cub_temp_storage_size,
                        int*            output_ids,
                        int*            sequence_length,
                        bool*           finished_buf,
                        float*          cum_log_probs,
                        float*          output_log_probs,
                        const T*        log_probs,
                        const int*      id_vals,
                        int*            offset_buf,
                        int*            begin_offset_buf,
                        curandState_t*  curandstate,
                        const int       batch_size,
                        const size_t    vocab_size_padded,
                        const int*      end_ids,
                        const float     top_p,
                        cudaStream_t    stream,
                        cudaDeviceProp* cuda_device_prop,
                        const bool*     skip_decode)
{
    const float* per_row_top_ps = nullptr;  // uniform threshold: no per-row values
    invokeBatchTopPSampling<T>(workspace,
                               workspace_size,
                               cub_temp_storage_size,
                               output_ids,
                               sequence_length,
                               finished_buf,
                               cum_log_probs,
                               output_log_probs,
                               log_probs,
                               id_vals,
                               offset_buf,
                               begin_offset_buf,
                               curandstate,
                               batch_size,
                               vocab_size_padded,
                               end_ids,
                               top_p,
                               per_row_top_ps,
                               stream,
                               cuda_device_prop,
                               skip_decode);
}

template void invokeTopPSampling(void*           workspace,
                                 size_t&         workspace_size,
                                 size_t&         cub_temp_storage_size,
                                 int*            output_ids,
                                 int*            sequence_length,
                                 bool*           finished_buf,
                                 float*          cum_log_probs,
                                 float*          output_log_probs,
                                 const float*    log_probs,
                                 const int*      id_vals,
                                 int*            offset_buf,
                                 int*            begin_offset_buf,
                                 curandState_t*  curandstate,
                                 const int       batch_size,
                                 const size_t    vocab_size_padded,
                                 const int*      end_ids,
                                 const float     top_p,
                                 cudaStream_t    stream,
                                 cudaDeviceProp* cuda_device_prop,
                                 const bool*     skip_decode);

template void invokeTopPSampling(void*           workspace,
                                 size_t&         workspace_size,
                                 size_t&         cub_temp_storage_size,
                                 int*            output_ids,
                                 int*            sequence_length,
                                 bool*           finished_buf,
                                 float*          cum_log_probs,
                                 float*          output_log_probs,
                                 const half*     log_probs,
                                 const int*      id_vals,
                                 int*            offset_buf,
                                 int*            begin_offset_buf,
                                 curandState_t*  curandstate,
                                 const int       batch_size,
                                 const size_t    vocab_size_padded,
                                 const int*      end_ids,
                                 const float     top_p,
                                 cudaStream_t    stream,
                                 cudaDeviceProp* cuda_device_prop,
                                 const bool*     skip_decode);

/**
 * In-place "add bias, then softmax" over each row of `logits`.
 *
 * Launch layout:
 * - One block per row; blockIdx.x selects the row.
 * - The block's threads stride over the row's n_padded columns.
 * - Every thread must reach the blockReduceMax / blockReduceSum calls, so
 *   there is deliberately no early exit.
 *
 * \param logits    [gridDim.x, n_padded], row-major; overwritten with probabilities
 * \param bias      [n] or nullptr (treated as zero)
 * \param end_ids   [gridDim.x]; read only for rows whose `finished` flag is set
 * \param finished  [gridDim.x] or nullptr (no row is finished)
 * \param n_padded  row stride and number of columns processed per row
 * \param n         number of real columns; columns in [n, n_padded) are masked to -MAX
 *
 * For a finished row, the bias is ignored. Column end_ids[row] is set to +MAX
 * and every other real column to -MAX, so after the softmax the end id carries
 * (almost) all of the probability mass.
 */
template<typename T>
__global__ void
addBiasSoftMax(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n_padded, const int n)
{
    const int  bid    = blockIdx.x;
    const bool finish = (finished != nullptr) ? finished[bid] : false;
    // 64-bit row offset: bid * n_padded can exceed INT_MAX for large batch x padded-vocab
    // products, and signed int overflow is undefined behavior.
    const size_t offset = static_cast<size_t>(bid) * static_cast<size_t>(n_padded);

    float            max_val   = -FLT_MAX;
    const bool       IS_FP16   = std::is_same<T, half>::value;
    const T          MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
    __shared__ float s_max_val;
    __shared__ float s_sum_val;

    // Pass 1: apply bias (or finished-row / padding masks) and track the running max.
    for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
        if (tid < n) {
            if (finish) {
                logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL;
            }
            else {
                T bias_val = (bias != nullptr) ? bias[tid] : (T)0.0f;
                logits[offset + tid] += bias_val;
            }
        }
        else {
            // Padding columns must never be sampled.
            logits[offset + tid] = -MAX_T_VAL;
        }
        max_val = max(max_val, (float)logits[offset + tid]);
    }

    max_val = blockReduceMax<float>((float)max_val);
    if (threadIdx.x == 0) {
        s_max_val = max_val;
    }
    __syncthreads();

    // Pass 2: exponentiate relative to the row max (numerical stability) and accumulate.
    float sum_val = 0.0f;
    for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
        logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val);
        sum_val += (float)logits[offset + tid];
    }

    sum_val = blockReduceSum<float>(sum_val);
    if (threadIdx.x == 0) {
        s_sum_val = sum_val;
    }
    __syncthreads();

    // Pass 3: normalize; the epsilon guards against division by zero.
    for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
        logits[offset + tid] = ((float)logits[offset + tid] / (s_sum_val + 1e-6f));
    }
}

/**
 * Host launcher for addBiasSoftMax: in-place bias-add + row-wise softmax over an
 * [m, n_padded] logits matrix, enqueued on `stream` (no synchronization, launch
 * errors are not checked here).
 *
 * \param logits    [m, n_padded], modified in place
 * \param bias      [n] or nullptr
 * \param end_ids   [m]; only consulted for rows flagged in `finished`
 * \param finished  [m] or nullptr
 * \param m         number of rows (one block per row)
 * \param n_padded  padded row length
 * \param n         real vocabulary size (presumably n <= n_padded - TODO confirm)
 *
 * Empty inputs (m, n or n_padded <= 0) are a no-op instead of an invalid
 * zero-sized launch.
 */
template<typename T>
void invokeAddBiasSoftMax(T*           logits,
                          const T*     bias,
                          const int*   end_ids,
                          const bool*  finished,
                          const int    m,
                          const int    n_padded,
                          const int    n,
                          cudaStream_t stream)
{
    if (m <= 0 || n <= 0 || n_padded <= 0) {
        return;
    }

    // n is the vocab size (e.g. 30000, 7000, ...), usually large, so most
    // launches hit the 1024-thread cap.
    //
    // The thread count is rounded up to a whole number of warps. A
    // partially-populated last warp would let lane-exchange reductions read
    // from lanes that do not exist. Extra threads start past n_padded (or
    // stride past it), so they skip the strided loops and contribute the
    // neutral -FLT_MAX / 0 to the reductions.
    // NOTE(review): assumes blockReduceMax / blockReduceSum are full-warp
    // shuffle reductions; confirm in reduce_kernel_utils.cuh.
    int threads = ((n + 31) / 32) * 32;
    if (threads > 1024) {
        threads = 1024;
    }

    dim3 grid(m);
    dim3 block(threads);
    addBiasSoftMax<<<grid, block, 0, stream>>>(logits, bias, end_ids, finished, n_padded, n);
}

template void invokeAddBiasSoftMax(float*       logits,
                                   const float* bias,
                                   const int*   end_ids,
                                   const bool*  finished,
                                   const int    m,
                                   const int    n_padded,
                                   const int    n,
                                   cudaStream_t stream);

template void invokeAddBiasSoftMax(half*        logits,
                                   const half*  bias,
                                   const int*   end_ids,
                                   const bool*  finished,
                                   const int    m,
                                   const int    n_padded,
                                   const int    n,
                                   cudaStream_t stream);

/**
 * Top-p decay rule for a single request.
 *
 * \return `initial_top_p` when the token just produced equals `reset_id`;
 *         otherwise max(runtime_top_p * decay, min_top_p).
 */
static __host__ __device__ inline float toppDecayStep(const float   runtime_top_p,
                                                      const float   initial_top_p,
                                                      const int     output_id,
                                                      const int32_t reset_id,
                                                      const float   decay,
                                                      const float   min_top_p)
{
    return (output_id == reset_id) ? initial_top_p : fmaxf(runtime_top_p * decay, min_top_p);
}

__global__ void computeToppDecay(float*         runtime_top_p,
                                 const float*   runtime_initial_top_p,
                                 const int*     output_ids,
                                 const float*   top_p_decay,
                                 const float*   top_p_min,
                                 const int32_t* top_p_reset_ids,
                                 const int      local_batch_size)
{
    /**
     * @brief Compute the top-p decay described in https://arxiv.org/pdf/2206.04624.pdf
     *        In short, the formula is
     *          runtime_top_p = max(runtime_top_p * top_p_decay, top_p_min)
     *        If the generated token equals top_p_reset_ids, runtime_top_p is
     *        reset to runtime_initial_top_p instead.
     *
     * One thread per request, 1D launch. Threads whose index is >= local_batch_size
     * do nothing, so the grid may be rounded up to whole blocks.
     *
     * \param runtime_top_p          [local_batch_size], updated in place
     * \param runtime_initial_top_p  [local_batch_size]
     * \param output_ids             [local_batch_size]
     * \param top_p_decay            [local_batch_size]
     * \param top_p_min              [local_batch_size]
     * \param top_p_reset_ids        [local_batch_size]
     * \param local_batch_size
     */

    const int idx = blockDim.x * blockIdx.x + threadIdx.x;
    // The launcher rounds the grid up to whole blocks; tail threads must not
    // read or write past the end of the per-request arrays.
    if (idx >= local_batch_size) {
        return;
    }
    runtime_top_p[idx] = toppDecayStep(runtime_top_p[idx],
                                       runtime_initial_top_p[idx],
                                       output_ids[idx],
                                       top_p_reset_ids[idx],
                                       top_p_decay[idx],
                                       top_p_min[idx]);
}

/**
 * Host launcher for computeToppDecay: one thread per request, up to 512 threads
 * per block, enqueued on `stream` (no synchronization).
 *
 * All pointer arguments are per-request device arrays of length local_batch_size.
 * A non-positive local_batch_size is a no-op. Otherwise block.x would be 0: the
 * ceil-division for the grid would divide by zero on the host, and the launch
 * itself would be an invalid configuration.
 */
void invokeComputeToppDecay(float*         runtime_top_p,
                            const float*   runtime_initial_top_p,
                            const int*     output_ids,
                            const float*   top_p_decay,
                            const float*   top_p_min,
                            const int32_t* top_p_reset_ids,
                            const int      local_batch_size,
                            cudaStream_t   stream)
{
    if (local_batch_size <= 0) {
        return;
    }
    dim3 block(min(local_batch_size, 512));
    dim3 grid((local_batch_size + block.x - 1) / block.x);
    computeToppDecay<<<grid, block, 0, stream>>>(
        runtime_top_p, runtime_initial_top_p, output_ids, top_p_decay, top_p_min, top_p_reset_ids, local_batch_size);
}

lvhan028's avatar
lvhan028 committed
1430
}  // namespace turbomind