logprob_kernels.cu 9.67 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif

lvhan028's avatar
lvhan028 committed
29
30
31
#include "src/turbomind/kernels/logprob_kernels.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/utils/logger.h"
Li Zhang's avatar
Li Zhang committed
32

lvhan028's avatar
lvhan028 committed
33
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

template<typename T>
__global__ void log_probs_kernel(float*       log_probs,
                                 const T*     logits,
                                 const int*   ids,
                                 const int*   lengths,
                                 const size_t max_input_length,
                                 const size_t batch_size,
                                 const size_t vocab_size,
                                 const size_t vocab_size_padded,
                                 bool         batch_first)
{
    // Calculate the log probability of each token from the logits:
    //   log_probs[t, :] = log(softmax(logits[t, :]))[ids[t + 1, :]]
    //
    // Launch layout: one thread block per (batch, step) pair; the threads of a
    // block stride over the vocabulary dimension.
    //   batch_first == true : blockIdx.x = batch index, blockIdx.y = step
    //   batch_first == false: blockIdx.x = step,        blockIdx.y = batch index
    // blockDim.x <= 32 takes the single-warp reduction path, otherwise the
    // block-wide reduction helpers are used.
    //
    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
    //     log probabilities of each token (output).
    // logits: [max_length, batch_size, vocab_size_padded] or [batch_size, max_length, vocab_size_padded]
    // ids: [max_length, batch_size] or [batch_size, max_length], token ids.
    //     NOTE(review): ids are used directly as an index into the vocab dim
    //     without a range check - assumed to be in [0, vocab_size_padded).
    // lengths: [batch_size], sequence lengths.
    // max_input_length: [1], max_length in the shapes above.
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.
    // vocab_size: [1], number of vocabulary entries that take part in the softmax.
    // vocab_size_padded: [1], padded vocab size (innermost stride of logits).

    const bool IS_FP16   = std::is_same<T, half>::value;
    const T    MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

    const int tidx = threadIdx.x;                            // vocab dim
    const int bidx = batch_first ? blockIdx.x : blockIdx.y;  // batch dim
    const int step = batch_first ? blockIdx.y : blockIdx.x;  // step dim

    __shared__ float s_max_logit;

    // The predicate depends only on blockIdx, so it is uniform across the block
    // and the __syncthreads() below is reached by either all threads or none.
    if (bidx < batch_size && step < lengths[bidx] - 1) {
        // Reposition logits to the row of the current (batch, step) pair.
        // Offsets stay in size_t: max_length * batch_size * vocab_size_padded
        // can exceed the range of a 32-bit int.
        const size_t step_stride  = batch_first ? vocab_size_padded : batch_size * vocab_size_padded;
        const size_t batch_stride = batch_first ? max_input_length * vocab_size_padded : vocab_size_padded;
        logits += static_cast<size_t>(step) * step_stride + static_cast<size_t>(bidx) * batch_stride;

        // Find max(logits) for a numerically stable softmax.
        float local_max = -MAX_T_VAL;
        float val       = -MAX_T_VAL;
        for (int i = tidx; i < vocab_size; i += blockDim.x) {
            val       = static_cast<float>(logits[i]);
            local_max = fmaxf(local_max, val);
        }

        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
        if (tidx == 0) {
            s_max_logit = max_val;
        }
        // Publish the block-wide maximum to every thread.
        __syncthreads();

        // Calculate the denominator: sum_i exp(logits[i] - max)
        float local_sum_exp = 0.0f;
        for (int i = tidx; i < vocab_size; i += blockDim.x) {
            val = __expf(static_cast<float>(logits[i]) - s_max_logit);
            local_sum_exp += val;
        }

        float sum_exp = blockDim.x <= 32 ? warpReduceSum(local_sum_exp) : blockReduceSum<float>(local_sum_exp);
        if (tidx == 0) {
            const size_t b = static_cast<size_t>(bidx);
            const size_t s = static_cast<size_t>(step);
            // Output slot for this (batch, step) pair.
            const size_t idx = batch_first ? s + b * (max_input_length - 1) : s * batch_size + b;
            // log_probs[step, ...] is the log probability of the token at step + 1.
            const size_t token_idx = batch_first ? s + 1 + b * max_input_length : (s + 1) * batch_size + b;
            // The 1e-9f keeps the argument of the logarithm strictly positive.
            log_probs[idx] = static_cast<float>(logits[ids[token_idx]]) - s_max_logit - __logf(sum_exp + 1e-9f);
        }
    }
}

__global__ void accumulate_log_probs(float*       cum_log_probs,
                                     const float* log_probs,
                                     const int*   lengths,
                                     const size_t max_input_length,
                                     const size_t batch_size,
                                     const bool   batch_first)
{
    // Sum the per-token log probabilities of every sequence over the step dim:
    //   cum_log_probs[j] = sum_{t < lengths[j] - 1} log_probs[t, j]
    //
    // Launch layout: one thread block per sequence (blockIdx.x); the threads of
    // the block stride over the steps of that sequence.
    //
    // cum_log_probs: [batch_size], cumulative log probability (output)
    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
    //   log probability of each token
    // lengths: [batch_size], sequence lengths
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.

    const int seq  = blockIdx.x;   // batch dim
    const int lane = threadIdx.x;  // step dim

    // Uniform per block: either every thread leaves here or none does.
    if (seq >= batch_size) {
        return;
    }

    // Number of scored tokens of this sequence.
    const int num_steps = lengths[seq] - 1;

    // First log-prob entry of this sequence and the distance between two
    // consecutive steps of the same sequence.
    const float* seq_log_probs = log_probs + (batch_first ? seq * (max_input_length - 1) : seq);
    const int    step_stride   = batch_first ? 1 : batch_size;

    float partial = 0.0f;
    for (int t = lane; t < num_steps; t += blockDim.x) {
        partial += static_cast<float>(seq_log_probs[t * step_stride]);
    }

    // A single warp reduces with shuffles only; larger blocks go through the
    // block-wide reduction helper.
    float total;
    if (blockDim.x <= 32) {
        total = warpReduceSum(partial);
    }
    else {
        total = blockReduceSum<float>(partial);
    }

    if (lane == 0) {
        cum_log_probs[seq] = total;
    }
}

template<typename T>
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first)
{
    // A batched version of log prob computation. Enqueues two kernels on
    // `stream`: log_probs_kernel writes per-token log probabilities into
    // `workspace`, then accumulate_log_probs sums them per sequence into
    // `cum_log_probs`. The function does not synchronize the stream.
    //
    // cum_log_probs: [batch_size] (output)
    // logits: [max_input_length, batch_size, vocab_size_padded] or
    //     [batch_size, max_input_length, vocab_size_padded]
    // input_ids: [max_input_length, batch_size] or [batch_size, max_input_length]
    // input_lengths: [batch_size]
    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.
    // batch_first: selects the second layout of each pair above.

    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    // Nothing to compute, and a zero-sized grid is an invalid launch configuration.
    if (batch_size == 0) {
        return;
    }

    assert(vocab_size > 0);
    assert(vocab_size <= vocab_size_padded);
    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);

    // block_size should be multiple of 32 to use warpReduceMax.
    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
    assert(block_size % 32 == 0);

    float* log_probs = reinterpret_cast<float*>(workspace);

    // Only max_input_length - 1 positions are scored (token t + 1 given the
    // logits at step t). Guard the subtraction: with max_input_length <= 1 the
    // unsigned value would be zero or wrap around, producing an invalid grid.
    if (max_input_length > 1) {
        const size_t       num_steps = max_input_length - 1;
        const unsigned int gx        = static_cast<unsigned int>(batch_first ? batch_size : num_steps);
        const unsigned int gy        = static_cast<unsigned int>(batch_first ? num_steps : batch_size);
        // The y dimension of a grid is limited to 65535 blocks.
        assert(gy <= 65535u);
        dim3 grid(gx, gy);
        log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
                                                             logits,
                                                             input_ids,
                                                             input_lengths,
                                                             max_input_length,
                                                             batch_size,
                                                             vocab_size,
                                                             vocab_size_padded,
                                                             batch_first);
    }

    // Always runs for a non-empty batch so that cum_log_probs is written even
    // when no position was scored (sequences of length <= 1 accumulate to 0).
    accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
// Explicit instantiation for fp32 logits.
template void invokeLogProbFromLogits<float>(float*       cum_log_probs,
                                             const float* logits,
                                             const int*   input_ids,
                                             const int*   input_lengths,
                                             const size_t max_input_length,
                                             const size_t batch_size,
                                             const size_t vocab_size,
                                             const size_t vocab_size_padded,
                                             void*        workspace,
                                             const size_t workspace_size,
                                             cudaStream_t stream,
                                             const bool   batch_first);

// Explicit instantiation for fp16 logits.
template void invokeLogProbFromLogits<half>(float*       cum_log_probs,
                                            const half*  logits,
                                            const int*   input_ids,
                                            const int*   input_lengths,
                                            const size_t max_input_length,
                                            const size_t batch_size,
                                            const size_t vocab_size,
                                            const size_t vocab_size_padded,
                                            void*        workspace,
                                            const size_t workspace_size,
                                            cudaStream_t stream,
                                            const bool   batch_first);
lvhan028's avatar
lvhan028 committed
210
}  // end of namespace turbomind