test_sampling_kernels.cu 36.2 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <algorithm>   // std::fill_n
#include <iostream>    // snprintf
#include <math.h>      // expf, log
#include <memory>      // std::unique_ptr
#include <stdlib.h>    // rand
#include <string>      // std::string
#include <vector>      // std::vector

#include <cublas_v2.h>
#include <cublasLt.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

lvhan028's avatar
lvhan028 committed
13
14
15
16
17
18
19
20
#include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/layers/DynamicDecodeLayer.h"
#include "src/turbomind/layers/sampling_layers/TopKSamplingLayer.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"
Li Zhang's avatar
Li Zhang committed
21
22
23

#include "tests/unittests/gtest_utils.h"

lvhan028's avatar
lvhan028 committed
24
using namespace turbomind;
Li Zhang's avatar
Li Zhang committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882

namespace {

// Plain aggregate describing one sampling-kernel test configuration.
// Brace-initialization order: {batch_size, vocab_size, beam_width, top_k, top_p, output_len}.
struct SamplingKernelTestParam {
    size_t batch_size;  // number of rows processed per kernel call
    size_t vocab_size;  // number of logits / probabilities per row
    size_t beam_width;  // only reported by toString() in this section of the file
    uint   top_k;       // k forwarded to the top-k sampling kernels
    float  top_p;       // p forwarded to the sampling kernels
    size_t output_len;  // number of sampling steps executed by a test run

    // Returns a one-line, human-readable description of the configuration; the
    // tests pass it to checkResult() as the label of the comparison.
    std::string toString() const
    {
        // size_t arguments are formatted with %zu: %ld expects `long`, which does
        // not match size_t on every ABI (e.g. LLP64).
        // Assumes fmtstr forwards to a printf-family formatter -- TODO confirm.
        return fmtstr("SamplingKernelTestParam[batch=%zu, vocab=%zu, beam=%zu, k=%u, p=%3.1f, output_len=%zu]",
                      batch_size,
                      vocab_size,
                      beam_width,
                      top_k,
                      top_p,
                      output_len);
    }
};

/////////////////////////////////// Tests //////////////////////////////////////////

template<typename T>
void computeProb(T* probs, T* logits, int batch_size, int vocab_size)
{
    // Host reference softmax.
    //   logits: batch_size x vocab_size, row-major (input).
    //   probs:  softmax(logits) taken along the vocab dimension (output).
    // All arithmetic is carried out in float, whether T is float or half,
    // because half arithmetic is not fully supported in host functions.
    for (int row = 0; row < batch_size; ++row) {
        const int row_offset = row * vocab_size;
        T*        row_logits = logits + row_offset;
        T*        row_probs  = probs + row_offset;

        // Largest logit of the row; subtracted below to keep expf() in range.
        float row_max = -FLT_MAX;
        for (int col = 0; col < vocab_size; ++col) {
            row_max = std::max(row_max, static_cast<float>(row_logits[col]));
        }

        // Softmax denominator.
        float denom = 0.0f;
        for (int col = 0; col < vocab_size; ++col) {
            denom += expf(static_cast<float>(row_logits[col]) - row_max);
        }

        // Normalize; EPSILON guards the division.
        for (int col = 0; col < vocab_size; ++col) {
            const float shifted = static_cast<float>(row_logits[col]) - row_max;
            row_probs[col]      = static_cast<T>(expf(shifted) / (denom + EPSILON));
        }
    }
}

template<typename T>
void computeLogProb(T* logprobs, T* logits, int batch_size, int vocab_size)
{
    // Host reference log-softmax.
    //   logits:   batch_size x vocab_size, row-major (input).
    //   logprobs: log(softmax(logits)) taken along the vocab dimension (output).
    // All arithmetic is carried out in float, whether T is float or half,
    // because half arithmetic is not fully supported in host functions.
    for (int row = 0; row < batch_size; ++row) {
        const int row_offset   = row * vocab_size;
        T*        row_logits   = logits + row_offset;
        T*        row_logprobs = logprobs + row_offset;

        // Largest logit of the row; subtracted below to keep expf() in range.
        float row_max = -FLT_MAX;
        for (int col = 0; col < vocab_size; ++col) {
            row_max = std::max(row_max, static_cast<float>(row_logits[col]));
        }

        // Sum of the shifted exponentials.
        float denom = 0.0f;
        for (int col = 0; col < vocab_size; ++col) {
            denom += expf(static_cast<float>(row_logits[col]) - row_max);
        }

        // log(softmax(x)) = (x - max) - log(sum); the log term is constant per row.
        const float log_denom = logf(denom + EPSILON);
        for (int col = 0; col < vocab_size; ++col) {
            const float shifted = static_cast<float>(row_logits[col]) - row_max;
            row_logprobs[col]   = static_cast<T>(shifted - log_denom);
        }
    }
}

// Base fixture shared by the sampling-kernel tests: owns a CUDA stream and a
// device allocator bound to that stream for the duration of each test.
template<typename T>
class SamplingKernelTest: public testing::Test {
public:
    // Creates the stream, then the allocator, and attaches the stream to it.
    void SetUp() override
    {
        check_cuda_error(cudaStreamCreate(&stream));
        allocator = new Allocator<AllocatorType::CUDA>(getDevice());
        allocator->setStream(stream);
    }

    // Releases the allocator first, then the stream it was bound to.
    // googletest still runs TearDown() when SetUp() fails, so both resources
    // default to nullptr and are only released if they were actually created.
    void TearDown() override
    {
        delete allocator;  // deleting nullptr is a no-op
        allocator = nullptr;
        if (stream != nullptr) {
            check_cuda_error(cudaStreamDestroy(stream));
            stream = nullptr;
        }
    }

protected:
    unsigned long long seed = 0;  // seed handed to invokeCurandInitialize by the derived tests
    cudaStream_t stream = nullptr;
    Allocator<AllocatorType::CUDA>* allocator = nullptr;
    // Never assigned in this class; the derived tests visible here allocate
    // their own local curand state buffers instead.
    curandState_t* curand_states = nullptr;
};

// Typed fixture for the top-k sampling kernels (invokeTopKSampling and
// invokeBatchTopKSampling).
//
// A run performs `output_len` sampling steps. At every step fresh random logits
// are generated on the host, their softmax is uploaded and sampled by the kernel,
// and the host accumulates the reference log-probability of each token the kernel
// emitted. The accumulated values are finally passed to checkResult() together
// with the kernel's cum_log_probs buffer.
//
// Host buffers are owned by std::vector / std::unique_ptr so nothing leaks when
// the kernel launcher throws (NotSupportedLargerThanK1024 expects an exception).
// Device buffers come from `allocator` and are not freed explicitly here
// (presumably the allocator releases them when the fixture deletes it -- TODO confirm).
template<typename T>
class TopKSamplingKernelTest: public SamplingKernelTest<T> {

protected:
    const int end_id = 0;  // token id that marks a sequence as finished
    using SamplingKernelTest<T>::seed;
    using SamplingKernelTest<T>::stream;
    using SamplingKernelTest<T>::allocator;
    using SamplingKernelTest<T>::curand_states;

public:
    // Runs invokeTopKSampling with a single (top_k, top_p) pair for the whole batch.
    void runTest(SamplingKernelTestParam param)
    {
        size_t batch_size  = param.batch_size;
        size_t vocab_size  = param.vocab_size;
        size_t output_len  = param.output_len;
        size_t max_seq_len = output_len;

        uint  top_k = param.top_k;
        float top_p = param.top_p;

        // Logit values in the host of shape (batch_size x vocab_size).
        std::vector<T> h_logits(batch_size * vocab_size);
        std::vector<T> h_probs(batch_size * vocab_size);
        std::vector<T> h_lprobs(batch_size * vocab_size);

        std::vector<int> h_output_ids(batch_size);
        std::vector<int> h_seq_lengths(batch_size);
        // std::vector<bool> has no contiguous bool storage, hence unique_ptr<bool[]>.
        std::unique_ptr<bool[]> h_finished(new bool[batch_size]);

        std::vector<float> expected_cum_lprobs(batch_size, 0.0f);

        // Local buffer; shadows the (unused) base-class member of the same name.
        curandState_t* curand_states =
            reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * batch_size, false));
        invokeCurandInitialize(curand_states, batch_size, seed, stream);

        size_t workspace_size = 0;
        // retrieve the workspace size of the top-k sampling kernel.
        invokeTopKSampling<T>(nullptr,
                              workspace_size,
                              nullptr,
                              nullptr,
                              nullptr,
                              nullptr,
                              nullptr,
                              nullptr,
                              nullptr,
                              top_k,
                              1.0f,
                              vocab_size,
                              nullptr,
                              stream,
                              batch_size,
                              nullptr);
        void* workspace = allocator->malloc(workspace_size);

        int*  end_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int*  seq_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        bool* finished    = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));

        T*     probs         = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size * vocab_size));
        float* cum_lprobs    = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));
        float* output_lprobs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * output_len * batch_size));
        int*   output_ids    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * max_seq_len * batch_size));

        // Init by zero.
        deviceFill(seq_lengths, batch_size, 0);
        deviceFill(finished, batch_size, false);
        deviceFill(end_ids, batch_size, end_id);

        deviceFill(cum_lprobs, batch_size, 0.0f);
        deviceFill(output_lprobs, output_len * batch_size, 0.0f);
        deviceFill(output_ids, max_seq_len * batch_size, 0);

        for (size_t step = 0; step < output_len; ++step) {
            initRandom(h_logits.data(), batch_size * vocab_size, -3.0f, 3.0f);
            computeProb(h_probs.data(), h_logits.data(), batch_size, vocab_size);
            cudaH2Dcpy(probs, h_probs.data(), batch_size * vocab_size);
            invokeTopKSampling(workspace,
                               workspace_size,
                               // Note that the kernel needs vocab probs instead of
                               // log-prob if cum_log_probs or output_log_probs are
                               // provided. It's because the sampling layer already
                               // preprocesses log_prob_buf when those are provided.
                               probs,
                               output_ids + step * batch_size,
                               seq_lengths,
                               finished,
                               cum_lprobs,
                               output_lprobs + step * batch_size,
                               curand_states,
                               top_k,
                               top_p,
                               vocab_size,
                               end_ids,
                               stream,
                               batch_size,
                               nullptr);

            // Compute reference.
            cudaD2Hcpy(h_output_ids.data(), output_ids + step * batch_size, batch_size);
            cudaD2Hcpy(h_seq_lengths.data(), seq_lengths, batch_size);
            cudaD2Hcpy(h_finished.get(), finished, batch_size);
            computeLogProb(h_lprobs.data(), h_logits.data(), batch_size, vocab_size);
            for (size_t i = 0; i < batch_size; ++i) {
                int idx = i * vocab_size + h_output_ids[i];
                // A token contributes only while the reported sequence length
                // exceeds the current step index.
                expected_cum_lprobs[i] += (int)step < h_seq_lengths[i] ? (float)h_lprobs[idx] : 0.0f;
                EXPECT_EQ(h_finished[i], h_output_ids[i] == end_id);
            }
        }
        bool passed = checkResult(param.toString(), cum_lprobs, expected_cum_lprobs.data(), batch_size);
        EXPECT_TRUE(passed);
    }

    // Runs invokeBatchTopKSampling with per-row runtime arguments.
    //   has_diff_runtime_args: when true, only every third row keeps param.top_k /
    //                          param.top_p; the others use k = 1 and 0.1 * top_p.
    //   use_skip_decode:       when true, even-indexed rows are flagged as skipped
    //                          and excluded from the host reference.
    void runBatchTest(SamplingKernelTestParam param, bool has_diff_runtime_args, bool use_skip_decode)
    {
        size_t batch_size = param.batch_size;
        size_t vocab_size = param.vocab_size;
        size_t output_len = param.output_len;
        size_t seq_len    = output_len;

        int   top_k = param.top_k;
        float top_p = param.top_p;

        std::vector<int>   h_top_ks(batch_size);
        std::vector<float> h_top_ps(batch_size);
        for (size_t i = 0; i < batch_size; ++i) {
            h_top_ks[i] = (!has_diff_runtime_args || i % 3 == 0) ? top_k : 1;
            h_top_ps[i] = (!has_diff_runtime_args || i % 3 == 0) ? top_p : 0.1 * top_p;
        }
        int max_top_k = *std::max_element(h_top_ks.begin(), h_top_ks.end());

        // Logit values in the host of shape (batch_size x vocab_size).
        std::vector<T> h_logits(batch_size * vocab_size);
        std::vector<T> h_probs(batch_size * vocab_size);
        std::vector<T> h_lprobs(batch_size * vocab_size);

        std::vector<float> expected_cum_lprobs(batch_size, 0.0f);

        std::vector<int> h_output_ids(batch_size);
        std::vector<int> h_seq_lengths(batch_size);
        // std::vector<bool> has no contiguous bool storage, hence unique_ptr<bool[]>.
        std::unique_ptr<bool[]> h_finished(new bool[batch_size]);
        std::unique_ptr<bool[]> h_skip_decode(new bool[batch_size]);

        // Kept although the step loop regenerates the logits: dropping this call
        // would change the random sequence consumed by the test.
        initRandom(h_logits.data(), batch_size * vocab_size, -3.0f, 3.0f);
        for (size_t i = 0; i < batch_size; ++i) {
            h_skip_decode[i] = use_skip_decode && (i % 2 == 0);
        }

        // Local buffer; shadows the (unused) base-class member of the same name.
        curandState_t* curand_states =
            reinterpret_cast<curandState_t*>(allocator->malloc(sizeof(curandState_t) * batch_size, false));
        invokeCurandInitialize(curand_states, batch_size, seed, stream);

        size_t workspace_size = 0;
        // retrieve the workspace size of the top-k sampling kernel.
        invokeBatchTopKSampling<T>(nullptr,  // workspace
                                   workspace_size,
                                   nullptr,  // log_probs
                                   nullptr,  // ids
                                   nullptr,  // sequence_lengths
                                   nullptr,  // finished
                                   nullptr,  // cum_log_probs
                                   nullptr,  // output_log_probs
                                   nullptr,  // curandstates
                                   max_top_k,
                                   nullptr,  // top_ks
                                   1.0f,
                                   nullptr,
                                   vocab_size,
                                   nullptr,  // end_ids
                                   stream,
                                   batch_size,
                                   nullptr);
        void* workspace = allocator->malloc(workspace_size, false);

        int*   top_ks = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        float* top_ps = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));

        int*  end_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int*  seq_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int*  output_ids  = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * seq_len * batch_size));
        bool* finished    = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));
        bool* skip_decode = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));

        T*     probs         = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size * vocab_size, true));
        float* cum_lprobs    = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));
        float* output_lprobs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * output_len * batch_size));

        // Initialize.
        cudaH2Dcpy(top_ks, h_top_ks.data(), batch_size);
        cudaH2Dcpy(top_ps, h_top_ps.data(), batch_size);
        cudaH2Dcpy(skip_decode, h_skip_decode.get(), batch_size);

        deviceFill(end_ids, batch_size, end_id);
        deviceFill(seq_lengths, batch_size, 0);
        deviceFill(finished, batch_size, false);
        deviceFill(cum_lprobs, batch_size, 0.0f);
        deviceFill(output_lprobs, output_len * batch_size, 0.0f);
        deviceFill(output_ids, seq_len * batch_size, 0);

        for (size_t step = 0; step < output_len; ++step) {
            initRandom(h_logits.data(), batch_size * vocab_size, -3.0f, 3.0f);
            computeProb(h_probs.data(), h_logits.data(), batch_size, vocab_size);
            cudaH2Dcpy(probs, h_probs.data(), batch_size * vocab_size);

            // NOTE(review): top_ps is uploaded above, but the call below passes
            // 1.0f / nullptr for the top-p arguments, so the per-row top_p values
            // never reach the kernel -- confirm this is intended.
            invokeBatchTopKSampling(workspace,
                                    workspace_size,
                                    // Note that the kernel needs vocab probs instead of
                                    // log-prob if cum_log_probs or output_log_probs are
                                    // provided. It's because the sampling layer already
                                    // preprocesses log_prob_buf when those are provided.
                                    probs,
                                    output_ids + step * batch_size,
                                    seq_lengths,
                                    finished,
                                    cum_lprobs,
                                    output_lprobs + step * batch_size,
                                    curand_states,
                                    max_top_k,
                                    top_ks,
                                    1.0f,
                                    nullptr,
                                    vocab_size,
                                    end_ids,
                                    stream,
                                    batch_size,
                                    skip_decode);

            // Compute reference.
            cudaD2Hcpy(h_output_ids.data(), output_ids + step * batch_size, batch_size);
            cudaD2Hcpy(h_seq_lengths.data(), seq_lengths, batch_size);
            cudaD2Hcpy(h_finished.get(), finished, batch_size);
            computeLogProb(h_lprobs.data(), h_logits.data(), batch_size, vocab_size);
            for (size_t i = 0; i < batch_size; ++i) {
                if (!h_skip_decode[i]) {
                    int idx = i * vocab_size + h_output_ids[i];
                    // A token contributes only while the reported sequence length
                    // exceeds the current step index.
                    expected_cum_lprobs[i] += (int)step < h_seq_lengths[i] ? (float)h_lprobs[idx] : 0.0f;
                    EXPECT_EQ(h_finished[i], h_output_ids[i] == end_id);
                }
            }
        }
        bool passed = checkResult(param.toString(), cum_lprobs, expected_cum_lprobs.data(), batch_size);
        EXPECT_TRUE(passed) << "Fail subtest (has_diff_runtime_args: " << has_diff_runtime_args
                            << ", skip_decode: " << use_skip_decode << ")";
    }

    // Runs the batch test for all four combinations of
    // (has_diff_runtime_args, use_skip_decode).
    void runBatchTest(SamplingKernelTestParam param)
    {
        this->runBatchTest(param, false, false);
        this->runBatchTest(param, false, true);
        this->runBatchTest(param, true,  false);
        this->runBatchTest(param, true,  true);
    }
};

// Register the TopKSamplingKernelTest cases for every type in FloatAndHalfTypes.
TYPED_TEST_SUITE(TopKSamplingKernelTest, FloatAndHalfTypes);

// Field order of SamplingKernelTestParam below:
//   {batch_size, vocab_size, beam_width, top_k, top_p, output_len}

// ---- invokeTopKSampling (one top_k / top_p for the whole batch) ----

// top_k = 1 on a tiny vocabulary, single step.
TYPED_TEST(TopKSamplingKernelTest, CorrectnessGreedy)
{
    this->runTest(SamplingKernelTestParam{6, 4, 1, 1, 1.0f, 1});
}

// top_k equal to the vocabulary size, single step.
TYPED_TEST(TopKSamplingKernelTest, CorrectnessAncestral)
{
    this->runTest(SamplingKernelTestParam{6, 4, 1, 4, 1.0f, 1});
}

// top_k = 63 on a 51200-entry vocabulary, 8 steps.
TYPED_TEST(TopKSamplingKernelTest, CorrectnessLargeK63)
{
    this->runTest(SamplingKernelTestParam{16, 51200, 1, 63, 1.0f, 8});
}

// top_k = 1024 on a 51200-entry vocabulary, 8 steps.
TYPED_TEST(TopKSamplingKernelTest, CorrectnessLargeK1024)
{
    this->runTest(SamplingKernelTestParam{16, 51200, 1, 1024, 1.0f, 8});
}

// top_k = 63 combined with top_p = 0.3.
TYPED_TEST(TopKSamplingKernelTest, CorrectnessTopKTopP)
{
    this->runTest(SamplingKernelTestParam{16, 4000, 1, 63, 0.3f, 8});
}

// top_k = 1025 is expected to be rejected with std::domain_error.
TYPED_TEST(TopKSamplingKernelTest, NotSupportedLargerThanK1024)
{
    EXPECT_THROW(this->runTest(SamplingKernelTestParam{16, 4000, 1, 1025, 1.0f, 8}), std::domain_error);
}

// ---- invokeBatchTopKSampling (per-row runtime arguments) ----

// top_k = 1 on a tiny vocabulary, single step.
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessGreedy)
{
    this->runBatchTest(SamplingKernelTestParam{6, 4, 1, 1, 1.0f, 1});
}

// top_k equal to the vocabulary size, single step.
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessAncestral)
{
    this->runBatchTest(SamplingKernelTestParam{6, 4, 1, 4, 1.0f, 1});
}

// top_k = 63 on a 4000-entry vocabulary, 8 steps.
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessLargeK63)
{
    this->runBatchTest(SamplingKernelTestParam{8, 4000, 1, 63, 1.0f, 8});
}

// top_k = 1024 on a 4000-entry vocabulary, 8 steps (top_p field is 0.0 here).
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessLargeK1024)
{
    this->runBatchTest(SamplingKernelTestParam{8, 4000, 1, 1024, 0.0f, 8});
}

// top_k = 63 with the top_p field set to 0.3.
TYPED_TEST(TopKSamplingKernelTest, BatchCorrectnessTopKTopP)
{
    this->runBatchTest(SamplingKernelTestParam{8, 4000, 1, 63, 0.3f, 8});
}


template<typename T>
class TopPSamplingKernelTest: public SamplingKernelTest<T> {

protected:
    const int end_id = 0;
    using SamplingKernelTest<T>::seed;
    using SamplingKernelTest<T>::stream;
    using SamplingKernelTest<T>::allocator;
    using SamplingKernelTest<T>::curand_states;

public:
    void runTest(SamplingKernelTestParam param)
    {
        size_t batch_size = param.batch_size;
        size_t vocab_size = param.vocab_size;
        size_t output_len = param.output_len;
        size_t seq_len = output_len;

        float top_p = param.top_p;

        // Logit values in the host of shape (batch_size x vocab_size).
        T* h_logits = new T[batch_size * vocab_size];
        T* h_probs  = new T[batch_size * vocab_size];
        T* h_lprobs = new T[batch_size * vocab_size];

        float* expected_cum_lprobs = new float[batch_size];
        std::fill_n(expected_cum_lprobs, batch_size, 0);

        int*  h_output_ids  = new int[batch_size];
        int*  h_seq_lengths = new int[batch_size];
        bool* h_finished    = new bool[batch_size];

        initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f);

        int device;
        cudaGetDevice(&device);
        struct cudaDeviceProp device_prop;
        cudaGetDeviceProperties(&device_prop, device);

        curandState_t* curand_states = reinterpret_cast<curandState_t*>(
            allocator->malloc(sizeof(curandState_t) * batch_size, false));
        invokeCurandInitialize(curand_states, batch_size, seed, stream);

        int* end_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int* seq_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int* output_ids  = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * seq_len * batch_size));

        bool* finished    = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));
        bool* skip_decode = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));

        T*     probs         = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size * vocab_size));
        float* cum_lprobs    = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));
        float* output_lprobs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * output_len * batch_size));

        int* begin_offsets    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * (batch_size + 1)));
        int* end_offsets      = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * (batch_size + 1)));
        int* topp_id_vals_buf = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size * vocab_size));

        size_t workspace_size = 0;
        size_t cub_temp_storage_size = 0;
        // retrieve the workspace size of the top-p sampling kernel.
        invokeTopPSampling<T>(nullptr,  // workspace
                              workspace_size,
                              cub_temp_storage_size,
                              nullptr,  // output_ids
                              nullptr,  // sequence_length
                              nullptr,  // finished_buffer
                              nullptr,  // cum_log_probs
                              nullptr,  // output_log_probs
                              (T*)nullptr,  // log_probs
                              topp_id_vals_buf,
                              end_offsets,
                              begin_offsets,
                              curand_states,
                              batch_size,
                              vocab_size,
                              nullptr,
                              top_p,
                              stream,
                              &device_prop,
                              nullptr);
        void* workspace = allocator->malloc(workspace_size);

        // Initialize.
        deviceFill(end_ids, batch_size, end_id);
        deviceFill(seq_lengths, batch_size, 0);
        deviceFill(finished, batch_size, false);
        deviceFill(cum_lprobs, batch_size, 0.0f);
        deviceFill(output_lprobs, output_len * batch_size, 0.0f);
        deviceFill(output_ids, seq_len * batch_size, 0);

        for (size_t step = 0; step < output_len; ++step) {
            initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f);
            computeProb(h_probs, h_logits, batch_size, vocab_size);
            cudaH2Dcpy(probs, h_probs, batch_size * vocab_size);

            invokeTopPInitialize(topp_id_vals_buf,
                                 end_offsets,
                                 begin_offsets,
                                 batch_size,
                                 vocab_size,
                                 stream);

            invokeTopPSampling<T>(workspace,
                                  workspace_size,
                                  cub_temp_storage_size,
                                  output_ids + step * batch_size,
                                  seq_lengths,
                                  finished,
                                  cum_lprobs,
                                  output_lprobs + step * batch_size,
                                  // Note that the kernel needs vocab probs instead of
                                  // log-prob if cum_log_probs or output_log_probs are
                                  // provided. It's because the sampling layer already
                                  // preprocesses log_prob_buf when those are provided.
                                  probs,
                                  topp_id_vals_buf,
                                  end_offsets,
                                  begin_offsets,
                                  curand_states,
                                  batch_size,
                                  vocab_size,
                                  end_ids,
                                  top_p,
                                  stream,
                                  &device_prop,
                                  nullptr);

            // Compute reference.
            cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size);
            cudaD2Hcpy(h_seq_lengths, seq_lengths, batch_size);
            cudaD2Hcpy(h_finished, finished, batch_size);
            computeLogProb(h_lprobs, h_logits, batch_size, vocab_size);
            for (size_t i = 0; i < batch_size; ++i) {
                int idx = i * vocab_size + h_output_ids[i];
                expected_cum_lprobs[i] += (int)step < h_seq_lengths[i] ? (float)h_lprobs[idx] : 0.0f;
                EXPECT_EQ(h_finished[i], h_output_ids[i] == end_id);
            }
        }
        bool passed = checkResult(param.toString(), cum_lprobs, expected_cum_lprobs, batch_size);
        EXPECT_TRUE(passed);

        delete[] expected_cum_lprobs;
        delete[] h_seq_lengths;
        delete[] h_logits;
        delete[] h_lprobs;
        delete[] h_probs;
        delete[] h_output_ids;
    }

    void runBatchTest(SamplingKernelTestParam param, bool has_diff_runtime_args, bool use_skip_decode)
    {
        size_t batch_size = param.batch_size;
        size_t vocab_size = param.vocab_size;

        float top_p = param.top_p;
        float* h_top_ps = new float[batch_size];
        // Initialize runtime top k values.
        for (size_t i = 0; i < batch_size; ++i) {
            h_top_ps[i] = (!has_diff_runtime_args || i % 3 == 0) ? top_p : 0.1 * top_p;
        }
        float max_top_p = *std::max_element(h_top_ps, h_top_ps + batch_size);

        size_t output_len = param.output_len;
        size_t seq_len = output_len;

        // Logit values in the host of shape (batch_size x vocab_size).
        T* h_logits = new T[batch_size * vocab_size];
        T* h_probs  = new T[batch_size * vocab_size];
        T* h_lprobs = new T[batch_size * vocab_size];

        float* expected_cum_lprobs = new float[batch_size];
        std::fill_n(expected_cum_lprobs, batch_size, 0);

        int*  h_output_ids  = new int[batch_size];
        int*  h_seq_lengths = new int[batch_size];
        bool* h_finished    = new bool[batch_size];
        bool* h_skip_decode = new bool[batch_size];

        initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f);
        std::fill_n(expected_cum_lprobs, batch_size, 0);
        for (size_t i = 0; i < batch_size; ++i) {
            h_skip_decode[i] = use_skip_decode && (i % 2 == 0);
        }

        int device;
        cudaGetDevice(&device);
        struct cudaDeviceProp device_prop;
        cudaGetDeviceProperties(&device_prop, device);

        curandState_t* curand_states = reinterpret_cast<curandState_t*>(
            allocator->malloc(sizeof(curandState_t) * batch_size, false));
        invokeCurandInitialize(curand_states, batch_size, seed, stream);

        float* top_ps = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));

        int* end_ids     = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int* seq_lengths = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size));
        int* output_ids  = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * seq_len * batch_size));

        bool* finished    = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));
        bool* skip_decode = reinterpret_cast<bool*>(allocator->malloc(sizeof(bool) * batch_size));

        T*     probs         = reinterpret_cast<T*>(allocator->malloc(sizeof(T) * batch_size * vocab_size));
        float* cum_lprobs    = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * batch_size));
        float* output_lprobs = reinterpret_cast<float*>(allocator->malloc(sizeof(float) * output_len * batch_size));

        int* begin_offsets    = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * (batch_size + 1)));
        int* end_offsets      = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * (batch_size + 1)));
        int* topp_id_vals_buf = reinterpret_cast<int*>(allocator->malloc(sizeof(int) * batch_size * vocab_size));

        size_t workspace_size = 0;
        size_t cub_temp_storage_size = 0;
        // retrieve the workspace size of the top-p sampling kernel.
        invokeBatchTopPSampling<T>(nullptr,  // workspace
                                   workspace_size,
                                   cub_temp_storage_size,
                                   nullptr,  // output_ids
                                   nullptr,  // sequence_length
                                   nullptr,  // finished_buffer
                                   nullptr,  // cum_log_probs
                                   nullptr,  // output_log_probs
                                   (T*)nullptr,  // log_probs
                                   topp_id_vals_buf,
                                   end_offsets,
                                   begin_offsets,
                                   curand_states,
                                   batch_size,
                                   vocab_size,
                                   nullptr,
                                   max_top_p,
                                   top_ps,
                                   stream,
                                   &device_prop,
                                   nullptr);
        void* workspace = allocator->malloc(workspace_size);

        // Initialize.
        cudaH2Dcpy(top_ps, h_top_ps, batch_size);
        cudaH2Dcpy(skip_decode, h_skip_decode, batch_size);
        deviceFill(end_ids, batch_size, end_id);
        deviceFill(seq_lengths, batch_size, 0);
        deviceFill(finished, batch_size, false);
        deviceFill(cum_lprobs, batch_size, 0.0f);
        deviceFill(output_lprobs, output_len * batch_size, 0.0f);
        deviceFill(output_ids, seq_len * batch_size, 0);

        for (size_t step = 0; step < output_len; ++step) {
            initRandom(h_logits, batch_size * vocab_size, -3.0f, 3.0f);
            computeProb(h_probs, h_logits, batch_size, vocab_size);
            cudaH2Dcpy(probs, h_probs, batch_size * vocab_size);

            invokeTopPInitialize(topp_id_vals_buf,
                                 end_offsets,
                                 begin_offsets,
                                 batch_size,
                                 vocab_size,
                                 stream);

            invokeBatchTopPSampling<T>(workspace,
                                       workspace_size,
                                       cub_temp_storage_size,
                                       output_ids + step * batch_size,
                                       seq_lengths,
                                       finished,
                                       cum_lprobs,
                                       output_lprobs + step * batch_size,
                                       // Note that the kernel needs vocab probs instead of
                                       // log-prob if cum_log_probs or output_log_probs are
                                       // provided. It's because the sampling layer already
                                       // preprocesses log_prob_buf when those are provided.
                                       probs,
                                       topp_id_vals_buf,
                                       end_offsets,
                                       begin_offsets,
                                       curand_states,
                                       batch_size,
                                       vocab_size,
                                       end_ids,
                                       max_top_p,
                                       top_ps,
                                       stream,
                                       &device_prop,
                                       skip_decode);

            // Compute reference.
            cudaD2Hcpy(h_output_ids, output_ids + step * batch_size, batch_size);
            cudaD2Hcpy(h_seq_lengths, seq_lengths, batch_size);
            cudaD2Hcpy(h_finished, finished, batch_size);
            computeLogProb(h_lprobs, h_logits, batch_size, vocab_size);
            for (size_t i = 0; i < batch_size; ++i) {
                if (!h_skip_decode[i]) {
                    int idx = i * vocab_size + h_output_ids[i];
                    expected_cum_lprobs[i] += (int)step < h_seq_lengths[i] ? (float)h_lprobs[idx] : 0.0f;
                    EXPECT_EQ(h_finished[i], h_output_ids[i] == end_id);
                }
            }
        }
        bool passed = checkResult(param.toString(), cum_lprobs, expected_cum_lprobs, batch_size);
        EXPECT_TRUE(passed) << "Fail subtest (has_diff_runtime_args: " << has_diff_runtime_args
                            << ", skip_decode: " << use_skip_decode << ")";

        delete[] expected_cum_lprobs;
        delete[] h_seq_lengths;
        delete[] h_logits;
        delete[] h_lprobs;
        delete[] h_probs;
        delete[] h_output_ids;
        delete[] h_top_ps;
        delete[] h_skip_decode;
    }

    // Convenience overload: runs the three-argument runBatchTest once for every
    // combination of its two boolean switches, in the order
    // (false, false), (false, true), (true, false), (true, true).
    // NOTE(review): the switches are presumably has_diff_runtime_args and
    // use_skip_decode, in that order -- confirm against the three-argument overload.
    void runBatchTest(SamplingKernelTestParam param)
    {
        const bool switch_values[] = {false, true};
        for (const bool first_switch : switch_values) {
            for (const bool second_switch : switch_values) {
                this->runBatchTest(param, first_switch, second_switch);
            }
        }
    }
};

// Register TopPSamplingKernelTest as a typed test suite: every TYPED_TEST below
// is instantiated once per type in the FloatAndHalfTypes type list.
TYPED_TEST_SUITE(TopPSamplingKernelTest, FloatAndHalfTypes);

// Calls the fixture's runTest with a low threshold value (0.2f).
// NOTE(review): the fifth initializer field is presumably top_p; the meaning of
// the other fields comes from SamplingKernelTestParam -- confirm before editing.
TYPED_TEST(TopPSamplingKernelTest, CorrectnessSmallP)
{
    constexpr float small_p = 0.2f;
    this->runTest({6, 4, 1, 0, small_p, 1});
}

// Calls the fixture's runTest with a high threshold value (0.9f).
// NOTE(review): the fifth initializer field is presumably top_p; the meaning of
// the other fields comes from SamplingKernelTestParam -- confirm before editing.
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeP)
{
    constexpr float large_p = 0.9f;
    this->runTest({6, 4, 1, 0, large_p, 1});
}

// Calls the fixture's runTest with a threshold of exactly 1.0f.
// NOTE(review): the fifth initializer field is presumably top_p, so 1.0f would
// keep the whole distribution -- confirm against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, CorrectnessAncestral)
{
    constexpr float ancestral_p = 1.0f;
    this->runTest({6, 4, 1, 0, ancestral_p, 1});
}

// Calls the fixture's runTest with larger sizes (32, 51200, ..., 16) and a low
// threshold value (0.2f).
// NOTE(review): the fields are presumably batch size, vocabulary size, ...,
// top_p and output length -- confirm against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabSmallP)
{
    constexpr float small_p = 0.2f;
    this->runTest({32, 51200, 1, 0, small_p, 16});
}

// Calls the fixture's runTest with larger sizes (32, 51200, ..., 16) and a high
// threshold value (0.9f).
// NOTE(review): the fields are presumably batch size, vocabulary size, ...,
// top_p and output length -- confirm against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabLargeP)
{
    constexpr float large_p = 0.9f;
    this->runTest({32, 51200, 1, 0, large_p, 16});
}

// Calls the fixture's one-argument runBatchTest (which covers all four
// combinations of its two boolean switches) with a low threshold value (0.2f).
// NOTE(review): the fifth initializer field is presumably top_p -- confirm
// against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, BatchCorrectnessSmallP)
{
    constexpr float small_p = 0.2f;
    this->runBatchTest({6, 4, 1, 0, small_p, 1});
}

// Calls the fixture's one-argument runBatchTest (which covers all four
// combinations of its two boolean switches) with a high threshold value (0.9f).
// NOTE(review): the fifth initializer field is presumably top_p -- confirm
// against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, BatchCorrectnessLargeP)
{
    constexpr float large_p = 0.9f;
    this->runBatchTest({6, 4, 1, 0, large_p, 1});
}

// Calls the fixture's one-argument runBatchTest with larger sizes
// (8, 4000, ..., 16) and a low threshold value (0.2f).
// NOTE(review): the fields are presumably batch size, vocabulary size, ...,
// top_p and output length -- confirm against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, BatchCorrectnessSmallP2)
{
    constexpr float small_p = 0.2f;
    this->runBatchTest({8, 4000, 1, 0, small_p, 16});
}

// Calls the fixture's one-argument runBatchTest with larger sizes
// (8, 4000, ..., 16) and a high threshold value (0.9f).
// NOTE(review): the fields are presumably batch size, vocabulary size, ...,
// top_p and output length -- confirm against SamplingKernelTestParam.
TYPED_TEST(TopPSamplingKernelTest, BatchCorrectnessLargeP2)
{
    constexpr float large_p = 0.9f;
    this->runBatchTest({8, 4000, 1, 0, large_p, 16});
}

// Draws one pseudo-random unsigned int from each curand state.
//
// Launch layout: 1D grid of 1D blocks with at least batch_size threads in
// total. The thread with global index i advances states[i] by one draw and
// writes the result to vals[i]; threads with i >= batch_size do nothing.
// No shared memory is used.
//
// The index is the global thread index rather than threadIdx.x alone, so a
// launch with more than one block maps each thread to its own state instead of
// having every block operate on the same leading entries.
//
// vals        [out]    batch_size generated values.
// states      [in/out] batch_size initialized curand states; each is advanced.
// batch_size  [in]     number of states / output values.
__global__
void generateRandomNumber(unsigned int *vals, curandState_t *states, const int batch_size) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < batch_size) {
        vals[idx] = curand(states + idx);
    }
}

// Checks invokeCurandBatchInitialize with per-entry seeds: curand states that
// were initialized with the same seed must produce the same first random number.
//
// Seeds are assigned in runs of period_size equal values (0, 0, 0, 1, 1, 1, ...),
// one random number is drawn from every state, and the draws inside each
// complete run are compared against the first draw of that run.
TEST(SamplingKernelTest, CurandBatchInitialize) {
    const size_t batch_size  = 127;
    const size_t period_size = 3;  // Number of consecutive states sharing one seed.

    cudaStream_t stream;
    check_cuda_error(cudaStreamCreate(&stream));

    curandState_t* curand_states = nullptr;
    check_cuda_error(cudaMalloc(&curand_states, sizeof(curandState_t) * batch_size));

    // Host buffers are std::vector so they are released on every exit path and
    // cannot be freed with the wrong form of delete.
    std::vector<unsigned long long> h_random_seeds(batch_size);
    for (size_t i = 0; i < batch_size; ++i) {
        h_random_seeds[i] = i / period_size;
    }
    unsigned long long* d_random_seeds = nullptr;
    check_cuda_error(cudaMalloc(&d_random_seeds, sizeof(unsigned long long) * batch_size));
    check_cuda_error(cudaMemcpy(d_random_seeds, h_random_seeds.data(),
                                sizeof(unsigned long long) * batch_size, cudaMemcpyHostToDevice));

    // Initialize curand states.
    invokeCurandBatchInitialize(curand_states, batch_size, d_random_seeds, stream);
    sync_check_cuda_error();

    // Generate random numbers using initialized curand states.
    unsigned int* d_rand_vals = nullptr;
    std::vector<unsigned int> h_rand_vals(batch_size);
    check_cuda_error(cudaMalloc(&d_rand_vals, sizeof(unsigned int) * batch_size));
    // A single block of batch_size threads: one thread per state.
    generateRandomNumber<<<1, static_cast<unsigned int>(batch_size), 0, stream>>>(
        d_rand_vals, curand_states, static_cast<int>(batch_size));
    // Kernel launches do not return a status; pick up launch-configuration errors here.
    check_cuda_error(cudaGetLastError());
    check_cuda_error(cudaMemcpyAsync(
        h_rand_vals.data(), d_rand_vals, sizeof(unsigned int) * batch_size, cudaMemcpyDeviceToHost, stream));
    check_cuda_error(cudaStreamSynchronize(stream));

    // The same seed produces the same random number. Only complete runs of
    // period_size entries are compared; a trailing partial run is ignored.
    for (size_t i = 0; i + period_size - 1 < batch_size; i += period_size) {
        for (size_t j = 1; j < period_size; ++j) {
            // Values are streamed rather than printf-formatted so that the
            // size_t indices and unsigned values are printed with their real types.
            EXPECT_EQ(h_rand_vals[i], h_rand_vals[i + j])
                << "Fail at val[" << i << "]=" << h_rand_vals[i]
                << " <> val[" << (i + j) << "]=" << h_rand_vals[i + j];
        }
    }

    check_cuda_error(cudaFree(d_rand_vals));
    check_cuda_error(cudaFree(d_random_seeds));
    check_cuda_error(cudaFree(curand_states));
    check_cuda_error(cudaStreamDestroy(stream));
}

}  // end of namespace