Unverified Commit 9efcac38 authored by Li Zhang, committed by GitHub

check-in fastertransformer (#7)

* add ft code

* gitignore

* fix lint

* revert fmha
parent 720fc533
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <cuda_runtime.h>
#include <stdint.h>
namespace fastertransformer {
enum class PositionEmbeddingType {
relative,
absolute,
};
template<typename T, typename Tindex>
void invokeGenRelativePosBias(T* relative_position_bias,
const T* relative_position_bias_table,
const Tindex* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream);
template<typename T>
void invokeBuildAlibiSlopes(T* linear_position_bias_slopes, const size_t head_num, cudaStream_t stream);
template<typename T, typename Tindex>
void invokeGenRelativePosBiasV2(T* relative_position_bias,
const T* relative_coords_table,
const Tindex* relative_position_bias_index,
const T* cpb_mlp_weight1,
const T* cpb_mlp_bias1,
const T* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/gpt_kernels.h"
#include "src/fastertransformer/utils/memory_utils.h"
namespace fastertransformer {
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T, bool OUTPUT_ID, int PROMPT_SRC>
__global__ void start_id_embedding_position_lookups_kernel(T* from_tensor,
int* output_ids,
const T* embedding_table,
const T* pos_table,
pPromptTuningParam<T> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int64_t hidden_units)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * length * hidden_units;
index += blockDim.x * gridDim.x) {
// transpose the input_ids [batch, length] (part of [batch, max_length]) to output_ids [length, batch]
if (OUTPUT_ID && index < batch_size * max_length) {
// for p/prompt tuning (prompt templates look like [input1, prompt1, input2, prompt2])
// we have to reorder them to [input1, input2, prompt1, prompt2] and then remove the prompts during
// post-processing
if (PROMPT_SRC > 0) {
if (index < batch_size) {
int no_prompt_output_seq_id = 0;
#pragma unroll 1
for (int seq_id = 0; seq_id < max_length; seq_id++) {
int current_input_id = input_ids[index * max_length + seq_id];
if (current_input_id < prompt_param.p_prompt_tuning_id_start) {
output_ids[no_prompt_output_seq_id * batch_size + index] = current_input_id;
no_prompt_output_seq_id++;
}
}
}
}
else {
const int seq_id = index % max_length;
const int batch_id = index / max_length;
if (seq_id < length) {
output_ids[seq_id * batch_size + batch_id] = input_ids[index];
}
}
}
// embedding lookup from word ids [batch, length] (part of [batch, max_length]) and [vocab, hidden] to generate
// embedding [batch, length, hidden]
const int word_index = index / hidden_units;
const int word_index_row = word_index / length; // batch_id
const int word_index_col = word_index % length;
const int real_word_index = word_index_row * max_length + word_index_col;
const int step = start_step + word_index % length;
const int col_index = index % hidden_units;
const int input_id = input_ids == nullptr ? real_word_index : input_ids[real_word_index];
const int prompt_id = input_id - prompt_param.p_prompt_tuning_id_start;
T embedding = (T)0.0f;
if (PROMPT_SRC > 0 && prompt_id >= 0) {
if (PROMPT_SRC == 1) {
// from loaded prompt embedding tables
embedding =
prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index];
}
else {
// from request prompt embedding
embedding =
prompt_param
.request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units
+ prompt_id * hidden_units + col_index];
}
}
else {
embedding = embedding_table[input_id * hidden_units + col_index];
}
T pos_embed = pos_table == nullptr ? (T)0.f : pos_table[(step - 1) * hidden_units + col_index];
from_tensor[index] = embedding + pos_embed;
}
}
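/*
 * Illustrative sketch (not part of the original FasterTransformer sources): a host-side
 * version of the index decomposition used by the kernel above. A flat index over
 * [batch, length, hidden] is split into (batch id, sequence position, hidden column),
 * and the sequence position is remapped into the padded input_ids [batch, max_length].
 * The struct and function names are assumptions made for this example only.
 */
struct SketchEmbeddingIndex {
    int batch_id;
    int seq_id;
    int hidden_id;
    int padded_input_idx;
};
static inline SketchEmbeddingIndex
sketch_decompose_embedding_index(int index, int length, int max_length, int hidden_units)
{
    const int word_index = index / hidden_units;  // which (batch, seq) token the flat index belongs to
    SketchEmbeddingIndex r;
    r.batch_id         = word_index / length;
    r.seq_id           = word_index % length;
    r.hidden_id        = index % hidden_units;
    r.padded_input_idx = r.batch_id * max_length + r.seq_id;  // position inside input_ids [batch, max_length]
    return r;
}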
#define WORD_POS_EMBEDDING_LOOPUP_KERNEL(OUTPUT_ID, PROMPT_SRC) \
start_id_embedding_position_lookups_kernel<T, OUTPUT_ID, PROMPT_SRC><<<grid, block, 0, stream>>>(from_tensor, \
output_ids, \
embedding_table, \
pos_table, \
prompt_param, \
input_ids, \
start_step, \
length, \
max_length, \
batch_size, \
hidden_units);
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor,
int* output_ids,
const T* embedding_table, // can also be inputs_embeds
const T* pos_table,
pPromptTuningParam<T> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream)
{
dim3 grid(min(batch_size * length, 65536));
dim3 block(min(hidden_units, 512));
const bool has_output_ids = output_ids != nullptr;
FT_CHECK(!(has_output_ids && input_ids == nullptr));
if (has_output_ids) {
if (prompt_param.use_request_p_prompt_embedding) {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 2);
}
else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 1);
}
else {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(true, 0);
}
}
else {
if (prompt_param.use_request_p_prompt_embedding) {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 2);
}
else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 1);
}
else {
WORD_POS_EMBEDDING_LOOPUP_KERNEL(false, 0);
}
}
}
template void invokeInputIdsEmbeddingLookupPosEncoding(float* from_tensor,
int* output_ids,
const float* embedding_table,
const float* pos_table,
pPromptTuningParam<float> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
template void invokeInputIdsEmbeddingLookupPosEncoding(half* from_tensor,
int* output_ids,
const half* embedding_table,
const half* pos_table,
pPromptTuningParam<half> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeInputIdsEmbeddingLookupPosEncoding(__nv_bfloat16* from_tensor,
int* output_ids,
const __nv_bfloat16* embedding_table,
const __nv_bfloat16* pos_table,
pPromptTuningParam<__nv_bfloat16> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#endif
template<typename T>
__global__ void inputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param)
{
// 1. Copy the input ids to output ids and transpose output ids to [seq_len, batch_size, beam_width].
// 2. Embedding lookup by input ids and concatenation with the soft prompt; the concatenation is along the seq_len axis.
// Assume batch size is 3, prompts are [[t1, t2], [t3], [t4, t5]] and input_ids are [[s1, s2], [s3], [s4]];
// then the order of output_ids is
// [ [?, ?, s1, s2]
// [?, s3, padding, padding]
// [?, ?, s4, padding] ]
// and the order of embedding is
// [ [t1, t2, s1, s2]
// [t3, s3, padding, padding]
// [t4, t5, s4, padding] ]
// where "?" means values that are still undefined and need to be filled in afterwards.
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < param.batch_size * param.beam_width * (param.max_prefix_soft_prompt_length + param.max_input_length)
* param.hidden_units;
index += blockDim.x * gridDim.x) {
// transpose the input_ids [batch, length] (part of [batch, beam, max_input_length]) to
// output_ids [length, batch, beam].
// output_ids need padding at the beginning for soft prompting.
if (index < param.batch_size * param.beam_width * param.max_input_length) {
int tmp_index = index;
const int seq_id = tmp_index % param.max_input_length;
tmp_index = (tmp_index - seq_id) / param.max_input_length;
const int beam_id = tmp_index % param.beam_width;
tmp_index = (tmp_index - beam_id) / param.beam_width;
const int batch_id = tmp_index % param.batch_size;
if (seq_id < param.max_input_length) {
param.output_ids[(param.prefix_soft_prompt_lengths[batch_id] + seq_id) * param.batch_size
* param.beam_width
+ batch_id * param.beam_width + beam_id] = param.input_ids[index];
}
}
// embedding lookup from word ids [batch, beam, length] (part of [batch, beam, max_input_length]), [vocab,
// hidden] and [batch, max_prefix_soft_prompt_length, hidden] to generate embedding [batch, beam, length +
// max_prefix_soft_prompt_length, hidden]
int tmp_index = index;
const int hidden_id = tmp_index % param.hidden_units;
tmp_index = (tmp_index - hidden_id) / param.hidden_units;
const int seq_id = tmp_index % (param.max_prefix_soft_prompt_length + param.max_input_length);
tmp_index = (tmp_index - seq_id) / (param.max_prefix_soft_prompt_length + param.max_input_length);
const int beam_id = tmp_index % param.beam_width;
tmp_index = (tmp_index - beam_id) / param.beam_width;
const int batch_id = tmp_index % param.batch_size;
const int64_t hidden_units = param.hidden_units;
T embedding =
(seq_id < param.prefix_soft_prompt_lengths[batch_id]) ?
(T)param.prefix_soft_prompt_embedding[batch_id * param.max_prefix_soft_prompt_length * hidden_units
+ seq_id * hidden_units + hidden_id] :
param.embedding_table[param.input_ids[batch_id * param.beam_width * param.max_input_length
+ beam_id * param.max_input_length
+ (seq_id - param.prefix_soft_prompt_lengths[batch_id])]
* hidden_units
+ hidden_id];
T pos_embed = param.pos_table == nullptr ?
(T)0.0f :
param.pos_table[(param.start_step + seq_id - 1) * hidden_units + hidden_id];
param.from_tensor[index] = embedding + pos_embed;
if (seq_id == 0 && hidden_id == 0) {
param.input_lengths[batch_id * param.beam_width + beam_id] += param.prefix_soft_prompt_lengths[batch_id];
}
}
}
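/*
 * Illustrative sketch (not part of the original sources): where the kernel above writes
 * each input id inside output_ids, which is stored as [seq_len, batch, beam]. With
 * prefix prompt lengths {2, 1, 2} and inputs {{s1, s2}, {s3}, {s4}}, token k of batch b
 * lands at sequence position prefix_len[b] + k, matching the layout sketched in the
 * comment above. The function name is an assumption made for this example only.
 */
static inline int sketch_soft_prompt_output_ids_offset(
    int seq_id, int batch_id, int beam_id, int prefix_len, int batch_size, int beam_width)
{
    // Same arithmetic as the kernel: shift by the soft-prompt length of this batch entry,
    // then linearize in [seq_len, batch, beam] order.
    return (prefix_len + seq_id) * batch_size * beam_width + batch_id * beam_width + beam_id;
}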
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param)
{
dim3 grid(min(param.batch_size * param.beam_width * (param.max_input_length + param.max_prefix_soft_prompt_length),
65536));
dim3 block(min(param.hidden_units, 512));
inputIdsEmbeddingLookupPosEncodingSoftPrompt<T><<<grid, block, 0, param.stream>>>(param);
}
template void
invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<float> param);
template void
invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<half> param);
#ifdef ENABLE_BF16
template void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(
inputIdsEmbeddingLookupPosEncodingSoftPromptParam<__nv_bfloat16> param);
#endif
// TODO Add half2 implementation
template<typename T>
__global__ void transposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2)
{
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1 * dim2) {
const int input_dim2_index = index % dim2;
index = (index - input_dim2_index) / dim2;
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
out[input_dim1_index * dim0 * dim2 + input_dim0_index * dim2 + input_dim2_index] =
in[input_dim0_index * dim1 * dim2 + input_dim1_index * dim2 + input_dim2_index];
}
}
template<typename T>
void invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(dim0 * dim1 * dim2 / 512.)));
transposeAxis01<<<grid, block, 0, stream>>>(out, in, dim0, dim1, dim2);
}
template void
invokeTransposeAxis01(float* out, float* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
template void
invokeTransposeAxis01(half* out, half* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
template void
invokeTransposeAxis01(int* out, int* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
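/*
 * Illustrative sketch (not part of the original sources): a CPU reference for the
 * transpose above, swapping the first two axes of a [dim0, dim1, dim2] tensor so that
 * out[i1][i0][i2] = in[i0][i1][i2]. The function name is an assumption for this example.
 */
template<typename T>
static void cpu_transpose_axis01(T* out, const T* in, int dim0, int dim1, int dim2)
{
    for (int i0 = 0; i0 < dim0; ++i0) {
        for (int i1 = 0; i1 < dim1; ++i1) {
            for (int i2 = 0; i2 < dim2; ++i2) {
                out[(i1 * dim0 + i0) * dim2 + i2] = in[(i0 * dim1 + i1) * dim2 + i2];
            }
        }
    }
}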
template<typename T>
__global__ void transposeAxis01(T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1)
{
// out: [dim1, dim0]
// in: [dim0, dim1]
// in_skipping_dim1: [dim1]
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < dim0 * dim1) {
const int input_dim1_index = index % dim1;
index = (index - input_dim1_index) / dim1;
const int input_dim0_index = index % dim0;
const int in_offset = in_skipping_dim1 == nullptr ? 0 : in_skipping_dim1[input_dim1_index] * dim1;
out[input_dim1_index * dim0 + input_dim0_index] = in[in_offset + input_dim0_index * dim1 + input_dim1_index];
}
}
template<typename T>
void invokeTransposeAxis01(
T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(dim0 * dim1 / 512.)));
transposeAxis01<<<grid, block, 0, stream>>>(out, in, in_skipping_dim1, dim0, dim1);
}
template void invokeTransposeAxis01(
int* out, int* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);
template<typename T, bool PREFIX_PROMPT>
__global__ void buildDecoderAttentionMaskKernel(T* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int max_seq_len,
const int max_prompt_length)
{
// sequence_lengths: [batch_size]
// attention_mask: [batch_size, 1, max_seq_len, max_seq_len + max_prompt_length]
const int max_prompt_seq_length = max_seq_len + max_prompt_length;
const int mask_size_per_seq = max_seq_len * max_prompt_seq_length;
attention_mask += blockIdx.x * mask_size_per_seq;
const int seq_length = sequence_lengths[blockIdx.x];
const int prompt_length = PREFIX_PROMPT ? prefix_prompt_lengths[blockIdx.x] : 0;
for (int i = threadIdx.x; i < mask_size_per_seq; i += blockDim.x) {
int row_id = i / max_prompt_seq_length;
int col_id = i % max_prompt_seq_length;
if (row_id < seq_length && col_id <= (row_id + prompt_length)) {
attention_mask[i] = (T)(1.0f);
}
else {
attention_mask[i] = (T)(0.0f);
}
}
}
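/*
 * Illustrative sketch (not part of the original sources): the mask produced by the kernel
 * above for one sequence with seq_length = 3, prompt_length = 2, max_seq_len = 4 and
 * max_prompt_length = 2 (rows are query positions, columns are prompt + key positions):
 *
 *   1 1 1 0 0 0
 *   1 1 1 1 0 0
 *   1 1 1 1 1 0
 *   0 0 0 0 0 0
 *
 * A minimal CPU reference under the same indexing (the function name is an assumption):
 */
static void cpu_build_decoder_attention_mask(
    float* mask, int seq_length, int prompt_length, int max_seq_len, int max_prompt_length)
{
    const int row_len = max_seq_len + max_prompt_length;
    for (int row = 0; row < max_seq_len; ++row) {
        for (int col = 0; col < row_len; ++col) {
            mask[row * row_len + col] = (row < seq_length && col <= row + prompt_length) ? 1.0f : 0.0f;
        }
    }
}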
template<typename T>
void invokeBuildDecoderAttentionMask(T* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream)
{
if (max_prompt_length == 0) {
buildDecoderAttentionMaskKernel<T, false><<<batch_size, 256, 0, stream>>>(
attention_mask, sequence_lengths, prefix_prompt_lengths, max_seq_len, max_prompt_length);
}
else {
buildDecoderAttentionMaskKernel<T, true><<<batch_size, 256, 0, stream>>>(
attention_mask, sequence_lengths, prefix_prompt_lengths, max_seq_len, max_prompt_length);
}
}
template void invokeBuildDecoderAttentionMask(float* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream);
template void invokeBuildDecoderAttentionMask(half* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBuildDecoderAttentionMask(__nv_bfloat16* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template void invokeBuildDecoderAttentionMask(__nv_fp8_e4m3* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream);
#endif
template<typename T>
__launch_bounds__(1024, 1) __global__ void lookupHiddenStateOfLastToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units)
{
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * hidden_units;
index += blockDim.x * gridDim.x) {
const int col_index = index % hidden_units;
const int batch_id = index / hidden_units;
from_tensor[index] = hidden_state[batch_id * max_input_length * hidden_units
+ (input_lengths[batch_id] - 1) * hidden_units + col_index];
}
}
template<typename T>
void invokeLookupHiddenStateOfLastToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream)
{
const int grid_size = (int)(ceil(batch_size * hidden_units / 1024.));
dim3 grid(min(grid_size, 65536));
dim3 block(min(hidden_units, 1024));
lookupHiddenStateOfLastToken<T><<<grid, block, 0, stream>>>(
from_tensor, hidden_state, input_lengths, max_input_length, batch_size, hidden_units);
}
template void invokeLookupHiddenStateOfLastToken(float* from_tensor,
const float* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
template void invokeLookupHiddenStateOfLastToken(half* from_tensor,
const half* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeLookupHiddenStateOfLastToken(__nv_bfloat16* from_tensor,
const __nv_bfloat16* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
#endif
template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int max_input_length)
{
if (threadIdx.x == 0) {
tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
if (PREFIX_PROMPT) {
tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
}
}
for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
input_ids[blockIdx.x * max_input_length + index];
}
}
void invokeTileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
dim3 grid(batch_size, beam_width);
dim3 block(min(1024, max_input_length));
if (prefix_prompt_lengths != nullptr) {
tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
else {
tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
tiled_input_lengths,
tiled_prompt_lengths,
input_ids,
input_lengths,
prefix_prompt_lengths,
max_input_length);
}
}
void invokeTileGptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
const int* input_ids,
const int* input_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
invokeTileGptPromptInputs(tiled_input_ids,
tiled_input_lengths,
nullptr,
input_ids,
input_lengths,
nullptr,
batch_size,
beam_width,
max_input_length,
stream);
}
void setSeqLimitLen(uint32_t* seq_len_d, Tensor seq_len, int limit_len_offset, int batch_size)
{
std::vector<uint32_t> seq_len_h(batch_size);
for (int i = 0; i < batch_size; i++) {
seq_len_h[i] = seq_len.getPtr<uint32_t>()[i] + limit_len_offset;
}
cudaH2Dcpy(seq_len_d, seq_len_h.data(), batch_size);
}
template<int TB_SIZE>
__global__ void
find_context_dups(int* shared_contexts, const int* input_ids, const size_t batch_size, const size_t input_seq_len)
{
/* We compare all context pairs (i, j), with i (tgt) < j (src), to detect duplicate
 * inputs. If there's a match between i and j, we store i at the
 * j-th position of shared_contexts, so that we know that j can be
 * represented by i. shared_contexts is initialized as shared_contexts[i] = i,
 * and when there's a match we actually use shared_contexts[j] = min(shared_contexts[j], i),
 * so that in the end shared_contexts effectively contains the index
 * of the match with the lowest context index.
 * Note that shared_contexts[i] <= i, a property that will be used when uncompacting
 * inputs.
 */
typedef cub::BlockReduce<int, TB_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ bool match;
/* Each block is responsible for a (i, j) pair. To map the block space to
* the i < j space, we need to convert a linear addressing to a triangle, of
* size (batch_size * (batch_size - 1)) / 2
* For more information, check https://en.wikipedia.org/wiki/Triangular_number
*/
// blockIdx = [0, 1, 2, ..., n(n-1)/2 - 1] -> base_index = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3, ..., n - 2]
const int base_index = floorf(0.5f * (sqrtf(1 + 8 * blockIdx.x) - 1));
const int src_idx = base_index + 1; // src_idx \in [1, batch_size)
const int rev_base_index = base_index * (base_index + 1) / 2;
const int tgt_idx = blockIdx.x - rev_base_index; // tgt_idx \in [0, src_idx)
const int padded_length = TB_SIZE * ((input_seq_len + TB_SIZE - 1) / TB_SIZE);
int sum = 0;
for (int i = threadIdx.x; i < padded_length; i += TB_SIZE) {
int compare =
(i >= input_seq_len) ? 1 : input_ids[tgt_idx * input_seq_len + i] == input_ids[src_idx * input_seq_len + i];
sum = BlockReduce(temp_storage).Sum(compare);
if (threadIdx.x == 0) {
match = (sum == TB_SIZE);
}
__syncthreads();
if (!match) {
break;
}
}
if (threadIdx.x == 0 && match) {
atomicMin(&shared_contexts[src_idx], tgt_idx);
}
}
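/*
 * Illustrative sketch (not part of the original sources): a host-side check of the
 * linear-block-index -> (tgt, src) triangular mapping used above. For batch_size = 4,
 * the n(n-1)/2 = 6 blocks map to the pairs (0,1) (0,2) (1,2) (0,3) (1,3) (2,3).
 * The function name is an assumption made for this example only.
 */
#include <cmath>
#include <cstdio>
static void sketch_triangle_mapping(int batch_size)
{
    for (int block = 0; block < batch_size * (batch_size - 1) / 2; ++block) {
        const int base_index = (int)floorf(0.5f * (sqrtf(1.0f + 8.0f * block) - 1.0f));
        const int src_idx    = base_index + 1;                             // src_idx in [1, batch_size)
        const int tgt_idx    = block - base_index * (base_index + 1) / 2;  // tgt_idx in [0, src_idx)
        printf("block %d -> (tgt=%d, src=%d)\n", block, tgt_idx, src_idx);
    }
}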
constexpr int DUPS_INDICES_BLOCK_SIZE = 128;
__global__ void generate_dups_indices(int* batch_to_compact,
int* compact_to_batch,
int* compact_size,
const int* shared_contexts,
const size_t batch_size,
const size_t input_seq_len)
{
const int padded_batchsize = blockDim.x * ((batch_size + blockDim.x - 1) / blockDim.x);
typedef cub::BlockScan<int, DUPS_INDICES_BLOCK_SIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ int scan_offset;
int scan = 0;
for (int batch = threadIdx.x; batch < padded_batchsize; batch += blockDim.x) {
bool masked = (batch >= batch_size);
bool first_iter = batch < blockDim.x;
int is_first_occur = masked ? 0 : shared_contexts[batch] == batch;
BlockScan(temp_storage).ExclusiveSum(is_first_occur, scan);
if (!masked && is_first_occur) {
int compact_idx = scan + (first_iter ? 0 : scan_offset);
// Context rep. writes initial index
batch_to_compact[batch] = compact_idx;
compact_to_batch[compact_idx] = batch;
}
if (threadIdx.x == blockDim.x - 1) {
scan_offset = scan + is_first_occur + (first_iter ? 0 : scan_offset);
}
__syncthreads();
if (!masked && !is_first_occur) {
// Fill the rest of batch_to_compact based on what rep. wrote
const int src_idx = batch_to_compact[shared_contexts[batch]];
batch_to_compact[batch] = src_idx;
}
}
if (threadIdx.x == 0) {
*compact_size = scan_offset;
}
}
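/*
 * Illustrative sketch (not part of the original sources): a sequential CPU reference for
 * the scan above. Because shared_contexts[i] <= i, each group representative appears
 * before its duplicates. With shared_contexts = {0, 0, 2, 2}, the representatives are
 * batches 0 and 2, so batch_to_compact = {0, 0, 1, 1}, compact_to_batch = {0, 2} and the
 * returned compact size is 2. The function name is an assumption for this example.
 */
static int cpu_generate_dups_indices(int*       batch_to_compact,
                                     int*       compact_to_batch,
                                     const int* shared_contexts,
                                     int        batch_size)
{
    int compact_size = 0;
    for (int batch = 0; batch < batch_size; ++batch) {
        if (shared_contexts[batch] == batch) {
            // Representative of its duplicate group: gets the next compact index.
            batch_to_compact[batch]        = compact_size;
            compact_to_batch[compact_size] = batch;
            compact_size++;
        }
        else {
            // Duplicate: reuse the compact index of its representative.
            batch_to_compact[batch] = batch_to_compact[shared_contexts[batch]];
        }
    }
    return compact_size;
}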
__global__ void init_shared_contexts(int* shared_contexts, const size_t batch_size)
{
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= batch_size) {
return;
}
shared_contexts[global_idx] = global_idx;
}
void invokeFindContextDups(int* shared_contexts,
int* batch_to_compact,
int* compact_to_batch,
int* compact_size,
const int* input_ids,
const size_t batch_size,
const size_t input_seq_len,
cudaStream_t stream)
{
dim3 block{512};
dim3 grid{((int)batch_size + block.x - 1) / block.x};
init_shared_contexts<<<grid, block, 0, stream>>>(shared_contexts, batch_size);
grid = dim3{(unsigned int)(batch_size * (batch_size - 1)) / 2};
if (input_seq_len <= 128) {
block = 128;
find_context_dups<128><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}
else {
block = 256;
find_context_dups<256><<<grid, block, 0, stream>>>(shared_contexts, input_ids, batch_size, input_seq_len);
}
generate_dups_indices<<<1, DUPS_INDICES_BLOCK_SIZE, 0, stream>>>(
batch_to_compact, compact_to_batch, compact_size, shared_contexts, batch_size, input_seq_len);
}
template<typename T>
__global__ void compact_inputs(T* compact_input,
T* compact_attention_mask,
int* compact_input_lengths,
const T* decoder_input,
const T* decoder_mask,
const int* input_lengths,
const int* compact_idx,
size_t compact_size,
size_t seq_len,
size_t hidden_dimension)
{
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx < compact_size * seq_len * hidden_dimension) {
const int h_id = global_idx % hidden_dimension;
const int seq_id = (global_idx / hidden_dimension) % seq_len;
const int batch_id = global_idx / (hidden_dimension * seq_len);
compact_input[global_idx] = decoder_input[(compact_idx[batch_id] * seq_len + seq_id) * hidden_dimension + h_id];
}
if (global_idx < compact_size * seq_len * seq_len) {
const int seq1_id = global_idx % seq_len;
const int seq2_id = (global_idx / seq_len) % seq_len;
const int batch_id = global_idx / (seq_len * seq_len);
compact_attention_mask[global_idx] =
decoder_mask[(compact_idx[batch_id] * seq_len + seq2_id) * seq_len + seq1_id];
}
if (global_idx < compact_size) {
compact_input_lengths[global_idx] = input_lengths[compact_idx[global_idx]];
}
}
template<typename T>
void invokeCompactInputs(T* compact_input,
T* compact_attention_mask,
int* compact_input_lengths,
const T* decoder_input,
const T* decoder_mask,
const int* input_lengths,
const int* compact_idx,
size_t compact_size,
size_t seq_len,
size_t hidden_dimension,
cudaStream_t stream)
{
/* Compact relevant decoder_layer inputs based on the identical contexts.
* For example, decoder_input is [batch_size, seq_len, H]. It's compacted
* into compact_input [compact_size, seq_len, H] such that
* compact_input[i, ...] = decoder_input[compact_idx[i], ...] */
const size_t elems_n = compact_size * seq_len * max(hidden_dimension, seq_len);
const dim3 blockDim(512);
const dim3 gridDim((elems_n + 512 - 1) / 512);
compact_inputs<T><<<gridDim, blockDim, 0, stream>>>(compact_input,
compact_attention_mask,
compact_input_lengths,
decoder_input,
decoder_mask,
input_lengths,
compact_idx,
compact_size,
seq_len,
hidden_dimension);
}
#define INSTANTIATE_INVOKE_COMPACT_INPUTS(T) \
template void invokeCompactInputs<T>(T * compact_input, \
T * compact_attention_mask, \
int* compact_input_lengths, \
const T* decoder_input, \
const T* decoder_mask, \
const int* input_lengths, \
const int* compact_idx, \
size_t compact_size, \
size_t seq_len, \
size_t hidden_dimension, \
cudaStream_t stream)
INSTANTIATE_INVOKE_COMPACT_INPUTS(half);
INSTANTIATE_INVOKE_COMPACT_INPUTS(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_COMPACT_INPUTS(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_COMPACT_INPUTS
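/*
 * Illustrative sketch (not part of the original sources): a CPU reference for the gather
 * that invokeCompactInputs performs on decoder_input, i.e.
 * compact_input[i, s, h] = decoder_input[compact_idx[i], s, h].
 * The function name is an assumption made for this example only.
 */
template<typename T>
static void cpu_compact_inputs(T*         compact_input,
                               const T*   decoder_input,
                               const int* compact_idx,
                               int        compact_size,
                               int        seq_len,
                               int        hidden_dimension)
{
    for (int i = 0; i < compact_size; ++i) {
        for (int s = 0; s < seq_len; ++s) {
            for (int h = 0; h < hidden_dimension; ++h) {
                compact_input[(i * seq_len + s) * hidden_dimension + h] =
                    decoder_input[(compact_idx[i] * seq_len + s) * hidden_dimension + h];
            }
        }
    }
}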
template<typename T>
__global__ void uncompact_outputs(T* uncompact_buffer,
const T* compact_buffer,
const int* batch_to_compact_idx,
size_t batch_size,
size_t buffer_stride)
{
/* Uncompact a buffer IN of size [Compact, Stride] into OUT of size [Batch, Stride]
* so that \forall i, OUT[i, :] = IN[batch_to_compact_idx[i], :]
*/
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (global_idx >= batch_size * buffer_stride) {
return;
}
const int stride_idx = global_idx % buffer_stride;
const int batch_idx = global_idx / buffer_stride;
const int src = batch_to_compact_idx[batch_idx];
uncompact_buffer[global_idx] = compact_buffer[src * buffer_stride + stride_idx];
}
template<typename T>
void invokeUnCompactOutputs(T* uncompact_buffer,
const T* compact_buffer,
const int* batch_to_compact_idx,
size_t batch_size,
size_t buffer_stride,
cudaStream_t stream)
{
const size_t num_elems = batch_size * buffer_stride;
const dim3 blockDim(1024);
const dim3 gridDim((num_elems + blockDim.x - 1) / blockDim.x);
uncompact_outputs<T><<<gridDim, blockDim, 0, stream>>>(
uncompact_buffer, compact_buffer, batch_to_compact_idx, batch_size, buffer_stride);
}
#define INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(T) \
template void invokeUnCompactOutputs(T* uncompact_buffer, \
const T* compact_buffer, \
const int* batch_to_compact_idx, \
size_t batch_size, \
size_t buffer_stride, \
cudaStream_t stream)
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(half);
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_UNCOMPACT_OUTPUTS
template<typename T>
__global__ void uncompact_caches(T* uncompact_k_cache,
T* uncompact_v_cache,
const T* compact_k_cache,
const T* compact_v_cache,
const int* batch_to_compact_idx,
size_t batch_size,
size_t num_heads,
size_t max_seq_len,
size_t seq_len,
size_t size_per_head,
size_t local_batch_size,
size_t ite)
{
const int hidden_dimension = num_heads * size_per_head;
const int num_elems_per_batch = seq_len * hidden_dimension;
const int num_elems_cache = batch_size * num_elems_per_batch;
const int x_size = 16 / sizeof(T);
for (int global_idx = blockIdx.x * blockDim.x + threadIdx.x; global_idx < 2 * num_elems_cache;
global_idx += blockDim.x * gridDim.x) {
const bool handle_k = global_idx < num_elems_cache;
const T* const cache_src = handle_k ? compact_k_cache : compact_v_cache;
T* const cache_dst = handle_k ? uncompact_k_cache : uncompact_v_cache;
const int idx = handle_k ? global_idx : global_idx - num_elems_cache;
const int src_offset = idx % num_elems_per_batch;
const int batch_idx = idx / num_elems_per_batch;
const int batch_src = batch_to_compact_idx[batch_idx] - ite * local_batch_size;
if (batch_src < 0 || batch_src >= local_batch_size) {
continue;
}
int dst_offset;
if (handle_k) {
const int i0 = idx % (x_size * seq_len);
const int i1 = (idx / (x_size * seq_len)) % (num_heads * size_per_head / x_size);
dst_offset = i1 * max_seq_len * x_size + i0;
}
else {
const int i0 = idx % (size_per_head * seq_len);
const int i1 = (idx / (size_per_head * seq_len)) % (num_heads);
dst_offset = i1 * max_seq_len * size_per_head + i0;
}
cache_dst[batch_idx * max_seq_len * hidden_dimension + dst_offset] =
cache_src[batch_src * num_elems_per_batch + src_offset];
}
}
template<typename T>
void invokeUnCompactCaches(T* uncompact_k_cache,
T* uncompact_v_cache,
const T* compact_k_cache,
const T* compact_v_cache,
const int* batch_to_compact_idx,
size_t batch_size,
size_t num_heads,
size_t max_seq_len,
size_t seq_len,
size_t size_per_head,
size_t local_batch_size,
size_t ite,
cudaStream_t stream)
{
const dim3 blockDim(512);
const dim3 gridDim(1024);
uncompact_caches<T><<<gridDim, blockDim, 0, stream>>>(uncompact_k_cache,
uncompact_v_cache,
compact_k_cache,
compact_v_cache,
batch_to_compact_idx,
batch_size,
num_heads,
max_seq_len,
seq_len,
size_per_head,
local_batch_size,
ite);
}
#define INSTANTIATE_INVOKE_UNCOMPACT_CACHES(T) \
template void invokeUnCompactCaches(T* uncompact_k_cache, \
T* uncompact_v_cache, \
const T* compact_k_cache, \
const T* compact_v_cache, \
const int* batch_to_compact_idx, \
size_t batch_size, \
size_t num_heads, \
size_t max_seq_len, \
size_t seq_len, \
size_t size_per_head, \
size_t local_batch_size, \
size_t ite, \
cudaStream_t stream)
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(half);
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_UNCOMPACT_CACHES(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_UNCOMPACT_CACHES
template<bool PREFIX_PROMPT>
__global__ void update_padding_count(int* total_padding_count,
const int* input_lengths,
const int* tiled_prompt_lengths,
size_t max_input_length,
size_t max_prompt_length,
size_t batch_size,
size_t beam_width)
{
const int gidx = blockIdx.x * blockDim.x + threadIdx.x;
if (gidx >= batch_size * beam_width) {
return;
}
const int batch_idx = gidx / beam_width;
total_padding_count[gidx] +=
PREFIX_PROMPT ? (max_input_length + max_prompt_length - input_lengths[batch_idx] - tiled_prompt_lengths[gidx]) :
(max_input_length - input_lengths[batch_idx]);
}
void invokeUpdatePaddingCount(int* total_padding_count,
const int* input_lengths,
const int* tiled_prompt_lengths,
size_t max_input_length,
size_t max_prompt_length,
size_t batch_size,
size_t beam_width,
cudaStream_t stream)
{
dim3 blockSize(256);
dim3 gridSize((batch_size * beam_width + blockSize.x - 1) / blockSize.x);
if (tiled_prompt_lengths != nullptr) {
update_padding_count<true><<<gridSize, blockSize, 0, stream>>>(total_padding_count,
input_lengths,
tiled_prompt_lengths,
max_input_length,
max_prompt_length,
batch_size,
beam_width);
}
else {
update_padding_count<false><<<gridSize, blockSize, 0, stream>>>(total_padding_count,
input_lengths,
tiled_prompt_lengths,
max_input_length,
max_prompt_length,
batch_size,
beam_width);
}
}
template<bool PREFIX_PROMPT>
__global__ void mask_padding_tokens(bool* masked_tokens,
const int* input_lengths,
const int* tiled_prefix_prompt_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t beam_width)
{
const int seq_len = PREFIX_PROMPT ?
(input_lengths[blockIdx.x / beam_width] + tiled_prefix_prompt_lengths[blockIdx.x]) :
input_lengths[blockIdx.x / beam_width];
for (int step = initial_step + seq_len + threadIdx.x; step < initial_step + max_input_length; step += blockDim.x) {
masked_tokens[blockIdx.x * memory_len + step % memory_len] = true;
}
}
void invokeMaskPaddingTokens(bool* masked_tokens,
const int* input_lengths,
const int* tiled_prefix_prompt_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t batch_size,
size_t beam_width,
cudaStream_t stream)
{
dim3 blockSize(128);
dim3 gridSize(batch_size * beam_width);
if (tiled_prefix_prompt_lengths != nullptr) {
mask_padding_tokens<true><<<gridSize, blockSize, 0, stream>>>(masked_tokens,
input_lengths,
tiled_prefix_prompt_lengths,
memory_len,
max_input_length,
initial_step,
beam_width);
}
else {
mask_padding_tokens<false><<<gridSize, blockSize, 0, stream>>>(masked_tokens,
input_lengths,
tiled_prefix_prompt_lengths,
memory_len,
max_input_length,
initial_step,
beam_width);
}
}
template<typename T>
__global__ void sum_length_dimension(
float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
const int bidx = blockIdx.x;
for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
float accum = 0.0f;
for (int step = 0; step < input_length; step++) {
accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
}
out_buf[bidx * hidden_dim + hidx] = accum;
}
}
template<typename T>
void invokeSumLengthDimension(float* out_buf,
const T* in_buf,
const size_t batch_size,
const size_t input_length,
const size_t hidden_dim,
cudaStream_t stream)
{
dim3 gridSize(batch_size);
dim3 blockSize(256);
sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}
#define INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(T) \
template void invokeSumLengthDimension(float* out_buf, \
const T* in_buf, \
const size_t batch_size, \
const size_t input_length, \
const size_t hidden_dim, \
cudaStream_t stream)
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(half);
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(float);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_SUM_LENGTH_DIMENSION
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <unordered_map>
#include "src/fastertransformer/utils/Tensor.h"
#include "src/fastertransformer/utils/memory_utils.h"
namespace fastertransformer {
template<typename T>
struct inputIdsEmbeddingLookupPosEncodingSoftPromptParam {
T* from_tensor;
int* output_ids;
int* input_lengths;
const T* embedding_table;
const T* pos_table;
const float* prefix_soft_prompt_embedding;
const int* prefix_soft_prompt_lengths;
int* input_ids;
int start_step;
int max_input_length;
int max_prefix_soft_prompt_length;
int batch_size;
int beam_width;
int hidden_units;
cudaStream_t stream;
};
template<typename T>
struct pPromptTuningParam {
// batch_size pointers, each pointing to the p/prompt tuning weights of the corresponding sequence
const T** p_prompt_tuning_batch_weights = nullptr;
// The start id of p_prompt_tuning token ids (based on the tokenizer)
// PROMPT_0 --> p_prompt_tuning_id_start; PROMPT_1 --> p_prompt_tuning_id_start + 1; ...
const int p_prompt_tuning_id_start = 0;
// Request prompt embedding's max length
const int request_prompt_max_length = 0;
// Whether or not to use the request prompt embeddings
const bool use_request_p_prompt_embedding = false;
// Request prompt embeddings
const T* request_prompt_embedding = nullptr;
};
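/*
 * Illustrative sketch (not part of the original sources): one way a caller might fill
 * pPromptTuningParam when the prompt embeddings come with the request. Aggregate
 * initialization is used because the members are const; the function name and argument
 * names are assumptions made for this example only.
 */
template<typename T>
pPromptTuningParam<T> sketch_request_prompt_param(const T* request_prompt_embedding,
                                                  int      p_prompt_tuning_id_start,
                                                  int      request_prompt_max_length)
{
    // Members in declaration order: batch weights are unused here because the
    // request-supplied embedding path is enabled instead.
    return pPromptTuningParam<T>{nullptr,                    // p_prompt_tuning_batch_weights
                                 p_prompt_tuning_id_start,   // p_prompt_tuning_id_start
                                 request_prompt_max_length,  // request_prompt_max_length
                                 true,                       // use_request_p_prompt_embedding
                                 request_prompt_embedding};  // request_prompt_embedding
}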
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncoding(T* from_tensor,
int* output_ids,
const T* embedding_table,
const T* pos_table,
pPromptTuningParam<T> prompt_param,
const int* input_ids,
const int start_step,
const int length,
const int max_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
template<typename T>
void invokeInputIdsEmbeddingLookupPosEncodingSoftPrompt(inputIdsEmbeddingLookupPosEncodingSoftPromptParam<T> param);
template<typename T>
void invokeTransposeAxis01(T* out, T* in, const int dim0, const int dim1, const int dim2, cudaStream_t stream);
template<typename T>
void invokeTransposeAxis01(
T* out, T* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);
template<typename T>
void invokeBuildDecoderAttentionMask(T* attention_mask,
const int* sequence_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int max_seq_len,
const int max_prompt_length,
cudaStream_t stream);
template<typename T>
void invokeLookupHiddenStateOfLastToken(T* from_tensor,
const T* hidden_state,
const int* input_lengths,
const int max_input_length,
const int batch_size,
const int hidden_units,
cudaStream_t stream);
void invokeTileGptPromptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
int* tiled_prompt_lengths,
const int* input_ids,
const int* input_lengths,
const int* prefix_prompt_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
void invokeTileGptInputs(int* tiled_input_ids,
int* tiled_input_lengths,
const int* input_ids,
const int* input_lengths,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
void invokeFindContextDups(int* shared_contexts,
int* batch_to_compact,
int* compact_to_batch,
int* compact_size,
const int* input_ids,
const size_t batch_size,
const size_t input_seq_len,
cudaStream_t stream = 0);
template<typename T>
void handleOptArg(TensorMap* input_tensors, const std::string& arg_name, T* d_ptr, T default_value, size_t size)
{
if (input_tensors->isExist(arg_name)) {
FT_CHECK(input_tensors->at(arg_name).size() == size);
cudaH2Dcpy(d_ptr, input_tensors->at(arg_name).getPtr<const T>(), size);
}
else {
deviceFill(d_ptr, size, default_value);
}
}
void setSeqLimitLen(uint32_t* seq_len_d, Tensor seq_len, int limit_len_offset, int batch_size);
template<typename T>
void invokeCompactInputs(T* compact_input,
T* compact_attention_mask,
int* compact_input_lengths,
const T* decoder_input,
const T* decoder_mask,
const int* input_lengths,
const int* compact_idx,
size_t compact_size,
size_t seq_len,
size_t hidden_dimension,
cudaStream_t stream = 0);
template<typename T>
void invokeUnCompactOutputs(T* uncompact_buffer,
const T* compact_buffer,
const int* batch_to_compact_idx,
size_t batch_size,
size_t buffer_stride,
cudaStream_t stream = 0);
template<typename T>
void invokeUnCompactCaches(T* uncompact_k_cache,
T* uncompact_v_cache,
const T* compact_k_cache,
const T* compact_v_cache,
const int* batch_to_compact_idx,
size_t batch_size,
size_t num_heads,
size_t max_seq_len,
size_t seq_len,
size_t size_per_head,
size_t local_batch_size,
size_t ite,
cudaStream_t stream = 0);
void invokeUpdatePaddingCount(int* total_padding_count,
const int* input_lengths,
const int* tiled_prompt_lengths,
size_t max_input_length,
size_t max_prompt_length,
size_t batch_size,
size_t beam_width,
cudaStream_t stream = 0);
inline void invokeUpdatePaddingCount(int* total_padding_count,
const int* input_lengths,
size_t max_input_length,
size_t batch_size,
size_t beam_width,
cudaStream_t stream = 0)
{
invokeUpdatePaddingCount(
total_padding_count, input_lengths, (const int*)nullptr, max_input_length, 0, batch_size, beam_width, stream);
}
void invokeMaskPaddingTokens(bool* masked_tokens,
const int* input_lengths,
const int* tiled_prefix_prompt_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t batch_size,
size_t beam_width,
cudaStream_t stream = 0);
inline void invokeMaskPaddingTokens(bool* masked_tokens,
const int* input_lengths,
const size_t memory_len,
const size_t max_input_length,
const size_t initial_step,
size_t batch_size,
size_t beam_width,
cudaStream_t stream = 0)
{
invokeMaskPaddingTokens(masked_tokens,
input_lengths,
(const int*)nullptr,
memory_len,
max_input_length,
initial_step,
batch_size,
beam_width,
stream);
}
template<typename T>
void invokeSumLengthDimension(float* out_buf,
const T* in_buf,
const size_t batch_size,
const size_t input_length,
const size_t hidden_dim,
cudaStream_t stream = 0);
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/logprob_kernels.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/utils/logger.h"
namespace fastertransformer {
template<typename T>
__global__ void log_probs_kernel(float* log_probs,
const T* logits,
const int* ids,
const int* lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
bool batch_first)
{
// Calculate the log probability from logits.
// log_probs[t, :] = log(softmax(logits))[ids[t + 1, :]]
//
// log_probs: [max_length - 1, batch_size] or [batch_size, max_length -1],
// log probabilities of each token.
// logits: [max_length, batch_size, vocab_size_padded] or [batch_size, max_length, vocab_size_padded]
// lengths: [batch_size], sequence lengths
// ids: [max_length, batch_size], token ids.
// batch_size: [1], batch size. In case of beam > 1, it is batch x beam.
// vocab_size: [1], vocab size.
// vocab_size_padded: [1], padded vocab size.
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
int tidx = threadIdx.x; // vocab dim
int bidx = batch_first ? blockIdx.x : blockIdx.y; // batch dim
int step = batch_first ? blockIdx.y : blockIdx.x; // step dim
__shared__ float s_max_logit;
if (bidx < batch_size && step < lengths[bidx] - 1) {
// reposition logits to data for the current batch.
int step_offset = batch_first ? step * vocab_size_padded : step * batch_size * vocab_size_padded;
int batch_offset = batch_first ? bidx * max_input_length * vocab_size_padded : bidx * vocab_size_padded;
logits += step_offset + batch_offset;
// Find max(logits).
float local_max = -MAX_T_VAL;
float val = -MAX_T_VAL;
for (int i = tidx; i < vocab_size; i += blockDim.x) {
val = static_cast<float>(logits[i]);
local_max = fmax(local_max, val);
}
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
if (tidx == 0) {
s_max_logit = max_val;
}
__syncthreads();
// Calculate the denominator: sum_i exp(logits[i])
float local_sum_exp = 0.0f;
for (int i = tidx; i < vocab_size; i += blockDim.x) {
val = __expf(static_cast<float>(logits[i]) - s_max_logit);
local_sum_exp += val;
}
float sum_exp = blockDim.x <= 32 ? warpReduceSum(local_sum_exp) : blockReduceSum<float>(local_sum_exp);
if (tidx == 0) {
int idx = batch_first ? step + bidx * (max_input_length - 1) : step * batch_size + bidx;
// log_probs[step, ...] is the log probability of a token at step t + 1.
int token_idx = batch_first ? step + 1 + bidx * max_input_length : (step + 1) * batch_size + bidx;
log_probs[idx] = static_cast<float>(logits[ids[token_idx]]) - s_max_logit - __logf(sum_exp + 1e-9f);
}
}
}
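/*
 * Illustrative sketch (not part of the original sources): a scalar CPU reference for the
 * per-token value computed above, using the same max-subtraction trick:
 * log p(id) = logit[id] - max(logits) - log(sum_i exp(logit[i] - max(logits))).
 * The function name is an assumption made for this example only.
 */
#include <algorithm>
#include <cmath>
static float cpu_token_log_prob(const float* logits, int vocab_size, int token_id)
{
    float max_logit = logits[0];
    for (int i = 1; i < vocab_size; ++i) {
        max_logit = std::max(max_logit, logits[i]);
    }
    float sum_exp = 0.0f;
    for (int i = 0; i < vocab_size; ++i) {
        sum_exp += std::exp(logits[i] - max_logit);
    }
    return logits[token_id] - max_logit - std::log(sum_exp + 1e-9f);
}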
__global__ void accumulate_log_probs(float* cum_log_probs,
const float* log_probs,
const int* lengths,
const size_t max_input_length,
const size_t batch_size,
const bool batch_first)
{
// Accumulate the log probability along with the sequence dimension.
// cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]
//
// cum_log_probs: [batch_size], cumulative log probability
// log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
// log probability of each token
// lengths: [batch_size], sequence lengths
// batch_size: [1], batch size. In case of beam > 1, it is batch x beam.
int bidx = blockIdx.x; // batch dim
int tidx = threadIdx.x; // step dim
if (bidx < batch_size) {
int length = lengths[bidx];
// reposition logits to data for the current batch.
log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;
int stride = batch_first ? 1 : batch_size; // stride along with seq dim.
float local_accum = 0.0f;
for (int step = tidx; step < length - 1; step += blockDim.x) {
local_accum += static_cast<float>(log_probs[step * stride]);
}
float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);
if (tidx == 0) {
cum_log_probs[bidx] = accum;
}
}
}
template<typename T>
void invokeLogProbFromLogits(float* cum_log_probs,
const T* logits,
const int* input_ids,
const int* input_lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
void* workspace,
const size_t workspace_size,
cudaStream_t stream,
const bool batch_first)
{
// A batched version of log prob computation.
//
// cum_log_probs: [batch_size]
// logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]
// input_ids: [max_input_length, batch_size] or [batch_size, max_input_length]
// input_lengths: [batch_size]
// workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
// block_size should be multiple of 32 to use warpReduceMax.
const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
assert(block_size % 32 == 0);
assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
assert(vocab_size <= vocab_size_padded);
float* log_probs = reinterpret_cast<float*>(workspace);
int gx = batch_first ? batch_size : max_input_length - 1;
int gy = batch_first ? max_input_length - 1 : batch_size;
dim3 grid(gx, gy);
log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
logits,
input_ids,
input_lengths,
max_input_length,
batch_size,
vocab_size,
vocab_size_padded,
batch_first);
accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}
template void invokeLogProbFromLogits(float* cum_log_probs,
const float* logits,
const int* input_ids,
const int* input_lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
void* workspace,
const size_t workspace_size,
cudaStream_t stream,
const bool batch_first);
template void invokeLogProbFromLogits(float* cum_log_probs,
const half* logits,
const int* input_ids,
const int* input_lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
void* workspace,
const size_t workspace_size,
cudaStream_t stream,
const bool batch_first);
} // end of namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace fastertransformer {
template<typename T>
void invokeLogProbFromLogits(float* cum_log_probs,
const T* logits,
const int* input_ids,
const int* input_lengths,
const size_t max_input_length,
const size_t batch_size,
const size_t vocab_size,
const size_t vocab_size_padded,
void* workspace,
const size_t workspace_size,
cudaStream_t stream,
const bool batch_first = false);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/online_softmax_beamsearch_kernels.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
#define DO_SPLIT_SMALL_TOP_K_SOFTMAX
static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;
#define TOPK_FP16_STORAGE 0
template<typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
// score = log(prob) / (length)^length_penalty.
if (length_penalty == 0.0f || length == 1) {
return log_prob;
}
return log_prob / static_cast<T>(powf(length, length_penalty));
}
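/*
 * Illustrative sketch (not part of the original sources): a host-side reference with a
 * worked example. With log_prob = -6.0, length = 4 and length_penalty = 1.0 the
 * normalized score is -6.0 / 4^1.0 = -1.5; with length_penalty = 0.0 the log probability
 * is returned unchanged. The function name is an assumption made for this example only.
 */
#include <cmath>
static float cpu_apply_length_penalty(float log_prob, int length, float length_penalty)
{
    if (length_penalty == 0.0f || length == 1) {
        return log_prob;
    }
    return log_prob / std::pow(static_cast<float>(length), length_penalty);
}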
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
TopK<T, MAX_K> partial;
if (thread_id == 0) {
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -FLT_MAX;
}
int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++) {
partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
}
index = block_id * MAX_K;
for (int i = 0; i < MAX_K; i++) {
id_buf[index + i] = partial.p[i];
}
}
}
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf,
const T* __restrict topk_tmp_val_buf,
int* __restrict id_buf,
T* __restrict val_buf)
{
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
TopK<T, MAX_K> partial;
if (thread_id == 0) {
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -FLT_MAX;
}
int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++) {
partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
}
index = block_id * MAX_K;
for (int i = 0; i < MAX_K; i++) {
id_buf[index + i] = partial.p[i];
val_buf[index + i] = partial.u[i];
}
}
}
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict x,
const T* __restrict y,
int* __restrict z,
float* __restrict v,
float* output_log_probs,
const bool* finished,
const int* sequence_lengths,
BeamHypotheses beam_hyps,
const int V,
const int K,
const int vocab_size,
const float length_penalty,
const T diversity_rate)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x;
// reposition x, y to data for the current vector
x += vector_id * V;
y += vector_id * V;
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ int selected_beams;
__shared__ float old_cum_log_probs[MAX_K];
if (thread_id == 0) {
selected_beams = 0;
}
if (thread_id < K) {
old_cum_log_probs[thread_id] = v[vector_id * K + thread_id];
}
__syncthreads();
if (beam_hyps.num_beams != nullptr) {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + vector_id;
if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0) {
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
}
else if (beam_hyps.num_beams[global_batch_idx] == K) {
return;
}
}
TopK<T, MAX_K> partial;
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -FLT_MAX;
}
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) {
int i = elem_id % K;
T elem = length_penalty == 0.0f ? y[elem_id] :
apply_length_penalty(y[elem_id],
finished[vector_id] ? sequence_lengths[vector_id] :
sequence_lengths[vector_id] + 1,
length_penalty);
elem += diversity_rate * (T)i;
int elem_idx = elem_id; // x[elem_id];
partial.insert(elem, elem_idx);
}
TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);
if (thread_id == 0) {
z += vector_id * K;
v += vector_id * K;
for (int i = 0; i < MAX_K; ++i) {
if (beam_hyps.num_beams != nullptr && x[total.p[i]] % vocab_size == beam_hyps.end_ids[vector_id]) {
// if the beam token does not belong to the top num_beams tokens, it should not be added. Adapted from
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
if (i >= K) {
// do nothing
}
else {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + vector_id;
const float normed_score = (float)total.u[i];
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
// If there are already beam_width finished sentences, check whether the score of the selected
// candidate is higher than min_normed_score. If the current score is better, replace the worst
// one and update min_normed_score.
if (num_beam == K) {
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) {
// stop the search and exit this for loop
selected_beams = K;
break;
}
else {
// find the beam whose score equals min_normed_score and replace it
for (int j = 0; j < K; j++) {
if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j]
== beam_hyps.min_normed_scores[global_batch_idx]) {
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] = normed_score;
for (int l = 0; l < K; l++) {
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * (K * 2) + l]);
}
break;
}
}
}
}
const int tgt_id_offset =
((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx)
* (beam_hyps.max_seq_len);
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = beam_hyps.end_ids[vector_id];
if (beam_hyps.log_probs != nullptr) {
beam_hyps.log_probs[tgt_id_offset + beam_hyps.step] =
(float)y[total.p[i]] - old_cum_log_probs[(x[total.p[i]] / vocab_size) % K];
}
int prev_id = (x[total.p[i]] / vocab_size) % K;
for (int j = beam_hyps.step - 1; j >= 0; j--) {
const int src_idx = j * beam_hyps.batch_size * K
+ beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) {
beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx];
}
prev_id = beam_hyps.parent_ids_src[src_idx];
}
const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
beam_hyps.num_beams[global_batch_idx]++;
beam_hyps.cum_log_probs[tgt_beam_idx] = (float)y[total.p[i]];
}
}
else if ((beam_hyps.num_beams != nullptr && i < 2 * K) || (beam_hyps.num_beams == nullptr && i < K)) {
z[selected_beams] = x[total.p[i]];
if (output_log_probs != nullptr) {
output_log_probs[vector_id * K + selected_beams] =
(float)y[total.p[i]] - old_cum_log_probs[(z[selected_beams] / vocab_size) % K];
}
v[selected_beams] = (float)y[total.p[i]];
selected_beams++;
}
__syncthreads();
if (selected_beams >= K) {
break;
}
}
}
if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr) {
if (beam_hyps.num_beams[blockIdx.x] < K) {
beam_hyps.is_done[blockIdx.x] = false;
}
else if (beam_hyps.early_stopping) {
beam_hyps.is_done[blockIdx.x] = true;
}
}
}
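// Running (max, denominator) pair for the online softmax: m is the running maximum and
// d = sum(exp(x_i - m)) over the elements seen so far. reduce_md_op merges two partial results
// by rescaling the denominator that belongs to the smaller maximum:
//   d = d_big + d_small * exp(m_small - m_big).
// Example: merging (m = 1.0, d = 2.0) with (m = 3.0, d = 1.0) gives
//   m = 3.0, d = 1.0 + 2.0 * exp(1.0 - 3.0).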
struct __align__(8) MD
{
float m;
float d;
};
__device__ __forceinline__ MD reduce_md_op(MD a, MD b)
{
bool a_bigger = (a.m > b.m);
MD bigger_m = a_bigger ? a : b;
MD smaller_m = a_bigger ? b : a;
MD res;
res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.m - bigger_m.m);
res.m = bigger_m.m;
return res;
}
template<typename T, int MAX_K>
struct TopKMD {
MD md;
TopK<T, MAX_K> topk;
};
template<typename T, int MAX_K>
__device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(const TopKMD<T, MAX_K>& a, const TopKMD<T, MAX_K>& b)
{
TopKMD<T, MAX_K> res;
res.md = reduce_md_op(a.md, b.md);
res.topk = reduce_topk_op(a.topk, b.topk);
return res;
}
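// Single-pass kernel: for each beam (one block), simultaneously accumulate the online-softmax
// statistics (m, d) and the top-MAX_K logits, then emit log-softmax scores u_i - m - log(d)
// plus the beam's cumulative log probability c. Finished beams only propagate their end_id token.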
template<typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict x,
const T* __restrict b,
const float* __restrict c,
const bool* __restrict finished,
int* __restrict z,
T* __restrict v,
int V,
int K,
const int* __restrict end_ids)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// reposition x to data for the current vector
x += vector_id * V;
typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopKMD<float, MAX_K> partial;
bool finish = finished[vector_id];
for (int i = 0; i < MAX_K; ++i) {
partial.topk.p[i] = -1;
partial.topk.u[i] = -MAX_T_VAL;
}
partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F;
if (finish) {
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) {
float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
// if (elem_id > THREADBLOCK_SIZE * MAX_K && (elem_id == E)) break;
}
}
else {
for (int elem_id = thread_id; elem_id < V; elem_id += THREADBLOCK_SIZE) {
float elem = x[elem_id] + b[elem_id];
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
}
}
TopKMD<float, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<float, MAX_K>);
if (thread_id == 0) {
z += vector_id * K;
v += vector_id * K;
c += vector_id;
// float d_total_inverse = __fdividef(1.0F, total.md.d);
float d_total_log = logf(total.md.d);
for (int i = 0; i < MAX_K; ++i) {
// float val = __expf(total.topk.u[i] - total.md.m) * d_total_inverse;
float val = total.topk.u[i] - total.md.m - d_total_log;
if (i < K) {
z[i] = total.topk.p[i] + vector_id * V; // FasterTransformer expects the absolute (flattened) id
v[i] = val + c[0];
}
}
}
}
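// Stage 1 of the split softmax/top-k: the vocabulary of each beam is partitioned into gridDim.y
// sections. Each block reduces its section and writes a packed record of 2 * MAX_K + 2 floats
// per (beam, section): the first MAX_K floats hold candidate ids (reinterpreted as ints), the
// next MAX_K hold candidate values, followed by the softmax denominator d and maximum m.
// Only the first 2 * K slots of each half are actually written.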
template<typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
void beam_online_softmax_topk_stage1_kernel(const T* __restrict x,
const T* __restrict b,
const bool* __restrict finished,
float* __restrict t,
int V,
int K,
const int* __restrict end_ids)
{
int thread_id = threadIdx.x;
int vector_id = blockIdx.x; // batch beam index.
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
// the vocabulary dimension V is split into gridDim.y sections, one section per blockIdx.y
const int v_local = (V + gridDim.y - 1) / gridDim.y;
const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local;
section_end = (section_end > V) ? V : section_end;
// reposition x to data for the current vector
x += vector_id * V;
#if TOPK_FP16_STORAGE == 1
typedef cub::BlockReduce<TopKMD<__half, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
#else
typedef cub::BlockReduce<TopKMD<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
#endif
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ float buf_s[PACKED_TOP_KMD_SIZE]; // save intermediate result
#if TOPK_FP16_STORAGE == 1
TopKMD<__half, MAX_K> partial;
#else
TopKMD<T, MAX_K> partial;
#endif
bool finish = finished[vector_id];
for (int i = 0; i < MAX_K; ++i) {
partial.topk.p[i] = -1;
partial.topk.u[i] = -MAX_T_VAL;
}
partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F;
if (finish) {
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) {
float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL;
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
}
}
else {
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) {
T bias = b == nullptr ? (T)0.0f : b[elem_id]; // gpt-2 does not use bias
T elem = x[elem_id] + bias;
MD new_elem{elem, 1.0F};
partial.md = reduce_md_op(partial.md, new_elem);
partial.topk.insert(elem, elem_id);
}
}
#if TOPK_FP16_STORAGE == 1
TopKMD<__half, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<__half, MAX_K>);
#else
TopKMD<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<T, MAX_K>);
#endif
if (thread_id == 0) {
for (int i = 0; i < 2 * K; i++) {
reinterpret_cast<int*>(buf_s)[i] = total.topk.p[i] + vector_id * V; // FasterTransformer expects the absolute (flattened) id
buf_s[MAX_K + i] = total.topk.u[i];
}
buf_s[2 * MAX_K] = total.md.d;
buf_s[2 * MAX_K + 1] = total.md.m;
}
__syncthreads();
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) {
t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id];
}
}
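// Stage 2: gather the per-section packed records of one beam into shared memory, let the first
// parts_per_beam threads unpack one record each into a TopKMD partial, block-reduce the partials,
// and emit the final top 2*K ids with normalized log probabilities (value - m - log(d) + cum_log_prob).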
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel(
const float* __restrict x, const float* __restrict c, int* __restrict z, T* __restrict v, int K, int parts_per_beam)
{
const int vector_id = blockIdx.x;
const int thread_id = threadIdx.x;
const int PACKED_TOP_KMD_SIZE = 2 * MAX_K + 2;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
extern __shared__ char buf_s_[]; // intermediate result
float* buf_s = reinterpret_cast<float*>(buf_s_);
//__shared__ float buf_s[PACKED_TOP_KMD_SIZE * THREADBLOCK_SIZE]; // intermediate result
typedef cub::BlockReduce<TopKMD<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
x += vector_id * PACKED_TOP_KMD_SIZE * parts_per_beam;
TopKMD<T, MAX_K> partial;
for (int i = 0; i < MAX_K; ++i) {
partial.topk.p[i] = -1;
partial.topk.u[i] = -MAX_T_VAL;
}
partial.md.m = -MAX_T_VAL;
partial.md.d = 0.0F;
// load and unpack into registers through smem
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * parts_per_beam; idx += THREADBLOCK_SIZE) {
buf_s[idx] = x[idx];
}
__syncthreads();
if (threadIdx.x < parts_per_beam) {
float* b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
for (int i = 0; i < 2 * K; i++) {
partial.topk.p[i] = reinterpret_cast<int*>(b_s)[i];
partial.topk.u[i] = b_s[MAX_K + i];
}
partial.md.d = b_s[2 * MAX_K];
partial.md.m = b_s[2 * MAX_K + 1];
}
__syncthreads();
TopKMD<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_md_op<T, MAX_K>);
if (thread_id == 0) {
z += vector_id * 2 * K;
v += vector_id * 2 * K;
c += vector_id;
float d_total_log = logf(total.md.d);
for (int i = 0; i < MAX_K; ++i) {
float val = (float)total.topk.u[i] - total.md.m - d_total_log;
if (i < 2 * K) {
z[i] = total.topk.p[i];
v[i] = (float)val + (float)c[0];
}
}
}
}
template<typename T, int MAX_K>
void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage,
const float* cum_log_probs,
int* ids,
T* vals,
int batch_size,
int beam_width,
int parts_per_beam,
cudaStream_t stream)
{
// beam_online_softmax_topk_stage2_kernel could be rewritten not to depend on a compile-time
// block size in order to reduce compilation time
int smem_stage2_size = parts_per_beam * (2 * MAX_K + 2) * sizeof(float);
if (parts_per_beam <= 32) {
beam_online_softmax_topk_stage2_kernel<T, MAX_K, 32><<<batch_size * beam_width, 32, smem_stage2_size, stream>>>(
temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
return;
}
if (parts_per_beam <= 64) {
beam_online_softmax_topk_stage2_kernel<T, MAX_K, 64><<<batch_size * beam_width, 64, smem_stage2_size, stream>>>(
temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
return;
}
if (parts_per_beam <= 128) {
beam_online_softmax_topk_stage2_kernel<T, MAX_K, 128>
<<<batch_size * beam_width, 128, smem_stage2_size, stream>>>(
temp_storage, cum_log_probs, ids, vals, beam_width, parts_per_beam);
return;
}
assert(0);
}
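// temp_storage layout (see the reinterpret_casts below):
//   [topk_tmp_id_buf: topk_buf_offset ints | topk_tmp_val_buf: topk_buf_offset T |
//    tmp_buffer: stage-1 packed partial results (floats)],
// where topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4) * 4.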
template<typename T, int MAX_K>
void topK_softMax_kernelLauncher(const T* log_probs,
const T* bias,
const bool* finished,
const int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
int* ids,
void* temp_storage,
const int temp_storage_size,
BeamHypotheses* beam_hyps,
const int batch_size,
const int beam_width,
const int vocab_size,
const int* end_ids,
T diversity_rate,
const float length_penalty,
cudaStream_t stream)
{
const int items_per_thread = 1;
const int block_sz = (MAX_K < 16) ? (MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128 : 64;
// const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE;
assert(temp_storage_size % 2 == 0);
assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2);
// Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4;
int* topk_tmp_id_buf = reinterpret_cast<int*>(temp_storage);
T* topk_tmp_val_buf = reinterpret_cast<T*>(topk_tmp_id_buf + topk_buf_offset);
float* tmp_buffer = reinterpret_cast<float*>(topk_tmp_val_buf + topk_buf_offset);
#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
int voc_parts = 4;
if (batch_size * beam_width < 256) {
// Volta has 80 SMs, so we aim for three waves
voc_parts = (240 + batch_size * beam_width - 1) / (batch_size * beam_width);
voc_parts = std::min(128, voc_parts); // we implement up to 128
}
dim3 grid(batch_size * beam_width, voc_parts);
cudaFuncSetAttribute(beam_online_softmax_topk_stage1_kernel<T, items_per_thread, 2 * MAX_K, block_sz>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxL1);
beam_online_softmax_topk_stage1_kernel<T, items_per_thread, 2 * MAX_K, block_sz>
<<<grid, block_sz, 0, stream>>>(log_probs, bias, finished, tmp_buffer, vocab_size, beam_width, end_ids);
sync_check_cuda_error();
#endif
if (beam_width > 1) {
#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
beam_online_softmax_topk_stage2_kernelLauncher<T, 2 * MAX_K>(
tmp_buffer, cum_log_probs, topk_tmp_id_buf, topk_tmp_val_buf, batch_size, beam_width, voc_parts, stream);
sync_check_cuda_error();
#else
beam_online_softmax_topk_kernel<T, items_per_thread, MAX_K, block_sz>
<<<batch_size * beam_width, block_sz, 0, stream>>>(log_probs,
bias,
cum_log_probs,
finished,
topk_tmp_id_buf,
topk_tmp_val_buf,
vocab_size,
beam_width,
end_ids);
#endif
#if 0
// wrong result with diversity_rate != 0.f
batch_topK_kernel<T, MAX_K, 32><<<batch_size, 32, 0, stream>>>
(topk_tmp_id_buf, topk_tmp_val_buf, ids, cum_log_probs);
#else
// We need 2*MAX_K candidates because up to K of them may already be finished, and
// finished candidates are not carried into the next iteration
batch_topk_kernel<T, MAX_K * 2, 32><<<batch_size, 32, 0, stream>>>(topk_tmp_id_buf,
topk_tmp_val_buf,
ids,
cum_log_probs,
output_log_probs,
finished,
sequence_lengths,
*beam_hyps,
beam_width * beam_width * 2,
beam_width,
vocab_size,
length_penalty,
diversity_rate);
sync_check_cuda_error();
#endif
}
else {
FT_CHECK(false);
#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
beam_online_softmax_topk_stage2_kernelLauncher<float, MAX_K>(
tmp_buffer, cum_log_probs, ids, cum_log_probs, batch_size, beam_width, voc_parts, stream);
#else
beam_online_softmax_topk_kernel<T, items_per_thread, MAX_K, block_sz>
<<<batch_size * beam_width, block_sz, 0, stream>>>(
log_probs, bias, cum_log_probs, finished, ids, cum_log_probs, vocab_size, beam_width, end_ids);
#endif
}
}
#define CASE_K(K, MAX_K) \
case K ... MAX_K: \
topK_softMax_kernelLauncher<T, MAX_K>(log_probs, \
bias, \
finished, \
sequence_lengths, \
cum_log_probs, \
output_log_probs, \
ids, \
temp_storage, \
temp_storage_size, \
beam_hyps, \
batch_size, \
beam_width, \
vocab_size, \
end_ids, \
diversity_rate, \
length_penalty, \
stream); \
break;
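// Dispatch on beam_width via case ranges (a GNU compiler extension); each range instantiates the
// launcher with the smallest MAX_K that covers the requested beam width.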
template<typename T>
void invokeTopkSoftMax(const T* log_probs,
const T* bias,
const bool* finished,
const int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
int* ids,
void* temp_storage,
const int temp_storage_size,
BeamHypotheses* beam_hyps,
const int batch_size,
const int beam_width,
const int vocab_size,
const int* end_ids,
const float diversity_rate,
const float length_penalty,
cudaStream_t stream)
{
switch (beam_width) {
CASE_K(1, 4);
CASE_K(5, 8);
CASE_K(9, 16);
CASE_K(17, 32);
CASE_K(33, 64);
default:
throw std::runtime_error(fmtstr("Topk kernel of beam search does not support beam_width=%d", beam_width));
}
}
#undef CASE_K
template void invokeTopkSoftMax<float>(const float* log_probs,
const float* bias,
const bool* finished,
const int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
int* ids,
void* tmp_storage,
const int temp_storage_size,
BeamHypotheses* beam_hyps,
const int batch_size,
const int beam_width,
const int vocab_size,
const int* end_ids,
const float diversity_rate,
const float length_penalty,
cudaStream_t stream);
template void invokeTopkSoftMax<half>(const half* log_probs,
const half* bias,
const bool* finished,
const int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
int* ids,
void* tmp_storage,
const int temp_storage_size,
BeamHypotheses* beam_hyps,
const int batch_size,
const int beam_width,
const int vocab_size,
const int* end_ids,
const float diversity_rate,
const float length_penalty,
cudaStream_t stream);
} // end of namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/kernels/beam_search_topk_kernels.h"
namespace fastertransformer {
template<typename T>
void invokeTopkSoftMax(const T* log_probs,
const T* bias,
const bool* finished,
const int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
int* ids,
void* tmp_storage,
const int temp_storage_size,
BeamHypotheses* beam_hyps,
const int batch_size,
const int beam_width,
const int vocab_size,
const int* end_ids,
const float diversity_rate,
const float length_penalty,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <string>
#include <unordered_map>
#include "src/fastertransformer/utils/string_utils.h"
namespace fastertransformer {
enum class RepetitionPenaltyType {
Additive, // the presence penalty
Multiplicative, // the repetition penalty
None // No repetition penalty.
};
inline float getDefaultPenaltyValue(RepetitionPenaltyType penalty_type)
{
switch (penalty_type) {
case RepetitionPenaltyType::Additive:
return 0.0f;
case RepetitionPenaltyType::Multiplicative:
return 1.0f;
default:
break;
}
return 0.0f;
}
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <float.h>
#include <type_traits>
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
namespace cg = cooperative_groups;
namespace fastertransformer {
template <int VPT>
struct BytesToType;
template <>
struct BytesToType<2>
{
using type = uint16_t;
};
template <>
struct BytesToType<4>
{
using type = uint32_t;
};
template <>
struct BytesToType<8>
{
using type = uint64_t;
};
template <>
struct BytesToType<16>
{
using type = float4;
};
template <int Bytes>
__device__ inline void copy(const void* local, void* data)
{
using T = typename BytesToType<Bytes>::type;
const T* in = static_cast<const T*>(local);
T* out = static_cast<T*>(data);
*out = *in;
}
static const float HALF_FLT_MAX = 65504.F;
#define FINAL_MASK 0xffffffff
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than a shift by 5) so that block sizes that are
// not a multiple of 32 are handled correctly
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template<typename T>
__inline__ __device__ T warpReduceMax(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get the max within each warp
if (lane == 0) // lane 0 records the warp max, indexed by warp id
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than a shift by 5) so that block sizes that are
// not a multiple of 32 are handled correctly
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
/* Calculate the maximum of all elements in a block */
template<typename T>
__inline__ __device__ T blockAllReduceMax(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
val = warpReduceMax(val); // get the max within each warp
if (lane == 0) // lane 0 records the warp max, indexed by warp id
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than a shift by 5) so that block sizes that are
// not a multiple of 32 are handled correctly
val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f;
val = warpReduceMax(val);
return val;
}
template<typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template<typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val)
{
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template<typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val)
{
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template<typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val)
{
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get the max within each warp
if (lane == 0) // lane 0 records the warp max, indexed by warp id
{
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Use blockDim.x / 32.f (rather than a shift by 5) so that block sizes that are
// not a multiple of 32 are handled correctly
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
template<int NUM>
__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm)
{
cg::thread_block cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
const int tid = cta.thread_rank();
const int blockz = blockDim.x;
for (int i = 0; i < NUM; i++) {
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus<float>());
#else
// TODO Add implementation here
if (threadIdx.x == 0 && blockIdx.x == 0) {
printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n");
assert(false);
}
#endif
}
cg::sync(cta);
if (tid == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
float beta = 0.0f;
for (int j = 0; j < blockz; j += 32) {
beta += cgBlockReduceSumElements_shm[i * blockz + j];
}
element_list[i] = beta;
}
}
}
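// Fixed-size top-K list kept in descending order of u, with ties broken by the smaller index p.
// insert() places a new element in the last slot if it beats the current minimum (or the slot is
// still empty, p == -1) and then bubbles it toward the front; init() resets all slots to
// (p = -1, u = -MAX).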
template<typename T, int MAX_K>
struct TopK {
int p[MAX_K];
T u[MAX_K];
__device__ __forceinline__ void insert(T elem, int elem_id)
{
if (elem > u[MAX_K - 1] || (p[MAX_K - 1] == -1) || ((elem == u[MAX_K - 1]) && (elem_id < p[MAX_K - 1])))
// if (elem > u[MAX_K-1] || ((elem == u[MAX_K-1]) && (elem_id < p[MAX_K-1])))
{
u[MAX_K - 1] = elem;
p[MAX_K - 1] = elem_id;
}
for (int k = MAX_K - 2; k >= 0; --k) {
if ((u[k + 1] > u[k]) || (p[k] == -1) || ((u[k + 1] == u[k]) && (p[k + 1] < p[k])))
// if ((u[k+1] > u[k]) || ((u[k+1] == u[k])&&(p[k+1] < p[k])))
{
T u2 = u[k];
int p2 = p[k];
u[k] = u[k + 1];
p[k] = p[k + 1];
u[k + 1] = u2;
p[k + 1] = p2;
}
}
}
__device__ __forceinline__ void init()
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
for (int i = 0; i < MAX_K; i++) {
p[i] = -1;
u[i] = -MAX_T_VAL;
}
}
};
template<typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b)
{
TopK<T, MAX_K> res = a;
for (int i = 0; i < MAX_K; ++i)
res.insert(b.u[i], b.p[i]);
return res;
}
template<typename T>
struct TopK_2 {
int p = -1;
T u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
__device__ __forceinline__ void insert(T elem, int elem_id)
{
if (elem > u) {
u = elem;
p = elem_id;
}
}
__device__ __forceinline__ void init()
{
u = -((std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX);
p = -1;
}
};
template<typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
{
return a.u > b.u ? a : b;
}
template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input)
{
return input;
}
template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input)
{
// clamp inf values to enable fp16 training
return input > 0.0f ? (half)min(input, HALF_FLT_MAX - 1000) : (half)max(input, -HALF_FLT_MAX + 1000);
}
#ifdef ENABLE_BF16
template<>
__device__ __forceinline__ __nv_bfloat16 clamp_inf_for_half(const float input)
{
return __float2bfloat16(input);
}
#endif
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <float.h>
#include "src/fastertransformer/kernels/sampling_penalty_kernels.h"
namespace fastertransformer {
// TODO Add half2 implementation
template<typename T>
__global__ void applyTemperaturePenalty(T* logits,
const T* bias,
const float temperature_inverse,
const int m,
const int vocab_size,
const int vocab_size_padd)
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < m * vocab_size_padd;
index += blockDim.x * gridDim.x) {
T bias_val = bias == nullptr ? (T)(0.0f) : bias[index % vocab_size_padd];
if (index % vocab_size_padd < vocab_size) {
logits[index] = (logits[index] + bias_val) * (T)temperature_inverse;
}
else {
logits[index] = -MAX_T_VAL;
}
}
}
template<>
__global__ void applyTemperaturePenalty(half2* logits,
const half2* bias,
const float temperature_inverse,
const int batch_size,
const int vocab_size,
const int vocab_size_padded)
{
assert(vocab_size % 2 == 0);
assert(vocab_size_padded % 2 == 0);
const half2 mask_val = __float2half2_rn(-65504.0f);
const half2 temp_inv = __float2half2_rn(temperature_inverse);
const int half_vocab_size = vocab_size / 2;
const int half_vocab_size_padded = vocab_size_padded / 2;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
index += blockDim.x * gridDim.x) {
int vocab_idx = index % half_vocab_size_padded;
half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val;
if (vocab_idx < half_vocab_size) {
if (bias != nullptr) {
logit = __hadd2(logit, bias[vocab_idx]);
}
logits[index] = __hmul2(logit, temp_inv);
}
}
}
template<typename T>
void invokeApplyTemperaturePenalty(T* logits,
const T* bias,
const float temperature,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream)
{
dim3 block(min(vocab_size_padd, 1024));
dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
const T temperature_inverse = (T)(1.f / (temperature + 1e-6f));
if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
applyTemperaturePenalty<<<grid, block, 0, stream>>>(reinterpret_cast<half2*>(logits),
reinterpret_cast<const half2*>(bias),
temperature_inverse,
batch_size,
vocab_size,
vocab_size_padd);
}
else {
applyTemperaturePenalty<T>
<<<grid, block, 0, stream>>>(logits, bias, temperature_inverse, batch_size, vocab_size, vocab_size_padd);
}
}
template void invokeApplyTemperaturePenalty(float* logits,
const float* bias,
const float temperature,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
template void invokeApplyTemperaturePenalty(half* logits,
const half* bias,
const float temperature,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
template<typename T>
__global__ void batchApplyTemperaturePenalty(T* logits,
const T* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padd)
{
// TODO: Add macro or device function to get MAX_T_VAL.
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? 65504.F : FLT_MAX;
extern __shared__ float inv_temperatures[];
if (threadIdx.x < batch_size) {
inv_temperatures[threadIdx.x] = 1.0f / (temperatures[threadIdx.x] + 1e-6f);
}
__syncthreads();
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * vocab_size_padd;
index += blockDim.x * gridDim.x) {
int batch_idx = index / vocab_size_padd;
int vocab_idx = index % vocab_size_padd;
T logit = (vocab_idx < vocab_size) ? logits[index] : -MAX_T_VAL;
if (vocab_idx < vocab_size) {
if (bias != nullptr) {
logit += bias[vocab_idx];
}
logit *= inv_temperatures[batch_idx];
}
logits[index] = logit;
}
}
__global__ void batchApplyTemperaturePenalty_h2(half2* logits,
const half2* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padded)
{
assert(vocab_size % 2 == 0);
assert(vocab_size_padded % 2 == 0);
extern __shared__ half2 h2_inv_temperatures[];
if (threadIdx.x < batch_size) {
h2_inv_temperatures[threadIdx.x] = __float2half2_rn(1.f / (temperatures[threadIdx.x] + 1e-6f));
}
__syncthreads();
const half2 mask_val = __float2half2_rn(-65504.0f);
const int half_vocab_size = vocab_size / 2;
const int half_vocab_size_padded = vocab_size_padded / 2;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * half_vocab_size_padded;
index += blockDim.x * gridDim.x) {
int batch_idx = index / half_vocab_size_padded;
int vocab_idx = index % half_vocab_size_padded;
half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val;
if (vocab_idx < half_vocab_size) {
if (bias != nullptr) {
logit = __hadd2(logit, bias[vocab_idx]);
}
logits[index] = __hmul2(logit, h2_inv_temperatures[batch_idx]);
}
}
}
template<typename T>
void invokeBatchApplyTemperaturePenalty(T* logits,
const T* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream)
{
dim3 block(min(vocab_size_padd, 1024));
dim3 grid(min(batch_size * vocab_size_padd / block.x, 65536));
if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padd % 2 == 0) {
size_t smem_size = sizeof(half2) * batch_size;
batchApplyTemperaturePenalty_h2<<<grid, block, smem_size, stream>>>(reinterpret_cast<half2*>(logits),
reinterpret_cast<const half2*>(bias),
temperatures,
batch_size,
vocab_size,
vocab_size_padd);
}
else {
size_t smem_size = sizeof(float) * batch_size;
batchApplyTemperaturePenalty<T>
<<<grid, block, smem_size, stream>>>(logits, bias, temperatures, batch_size, vocab_size, vocab_size_padd);
}
}
template void invokeBatchApplyTemperaturePenalty(float* logits,
const float* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
template void invokeBatchApplyTemperaturePenalty(half* logits,
const half* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
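// Repetition penalties follow the convention in penalty_types.h: Additive (presence penalty)
// subtracts `penalty` from the logit of every previously generated token; Multiplicative
// (repetition penalty) divides positive logits by `penalty` and multiplies negative logits by it,
// e.g. with penalty = 1.2 a logit of 2.0 becomes 2.0 / 1.2 and a logit of -2.0 becomes -2.0 * 1.2.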
template<typename T, RepetitionPenaltyType penalty_type>
__global__ void applyRepetitionPenalty(T* logits,
const float penalty,
const int* start_ids,
int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int vocab_size_padd,
const int* input_lengths,
const int max_input_len,
const int step)
{
extern __shared__ float penalty_logits[];
int* penalty_indices = (int*)(penalty_logits + step);
logits = logits + blockIdx.x * vocab_size_padd;
const int input_length = input_lengths != nullptr ? input_lengths[blockIdx.x] : max_input_len;
for (int index = threadIdx.x; index < step; index += blockDim.x) {
if (index >= input_length && index < max_input_len) {
continue;
}
// output_ids shape: (input_len + output_len, batch_size)
int penalty_index = output_ids[index * batch_size + blockIdx.x];
if (penalty_index >= vocab_size) {
continue;
}
penalty_indices[index] = penalty_index;
float logit = (float)logits[penalty_index];
if (penalty_type == RepetitionPenaltyType::Additive) {
penalty_logits[index] = logit - penalty;
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
}
else if (penalty_type == RepetitionPenaltyType::None) {
penalty_logits[index] = logit;
}
else {
// Unsupported type
assert(false);
}
}
if (blockDim.x > 32) {
__syncthreads();
}
for (int index = threadIdx.x; index < step; index += blockDim.x) {
if (index >= input_length && index < max_input_len) {
continue;
}
// output_ids shape: (input_len + output_len, batch_size)
if (penalty_indices[index] >= vocab_size) {
continue;
}
logits[penalty_indices[index]] = penalty_logits[index];
}
}
template<typename T>
void invokeApplyRepetitionPenalty(T* logits,
const float penalty,
const int* start_ids,
int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int vocab_size_padd,
const int* input_lengths,
const int max_input_len,
const int step,
const RepetitionPenaltyType penalty_type,
cudaStream_t stream)
{
dim3 block(min(step, 1024));
dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));
if (penalty_type == RepetitionPenaltyType::Additive) {
applyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(logits,
penalty,
start_ids,
output_ids,
batch_size,
local_batch_size,
vocab_size,
vocab_size_padd,
input_lengths,
max_input_len,
step);
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
applyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
<<<grid, block, smem_size, stream>>>(logits,
penalty,
start_ids,
output_ids,
batch_size,
local_batch_size,
vocab_size,
vocab_size_padd,
input_lengths,
max_input_len,
step);
}
else if (penalty_type == RepetitionPenaltyType::None) {
// do nothing
}
}
template void invokeApplyRepetitionPenalty(float* logits,
const float penalty,
const int* start_ids,
int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int vocab_size_padd,
const int* input_lengths,
const int max_input_len,
const int step,
const RepetitionPenaltyType penalty_type,
cudaStream_t stream);
template void invokeApplyRepetitionPenalty(half* logits,
const float penalty,
const int* start_ids,
int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int vocab_size_padd,
const int* input_lengths,
const int max_input_len,
const int step,
const RepetitionPenaltyType penalty_type,
cudaStream_t stream);
template<typename T, RepetitionPenaltyType penalty_type>
__global__ void batchApplyRepetitionPenalty(T* logits,
const float* penalties,
const int* output_ids,
const int batch_size,
const int vocab_size,
const int* input_lengths,
const int max_input_length,
const int step)
{
extern __shared__ float penalty_logits[];
int* penalty_indices = (int*)(penalty_logits + step);
const int batch_idx = blockIdx.x;
const float penalty = penalties[batch_idx];
const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
logits += batch_idx * vocab_size;
// Phase 1. Find indices to penalize and keep the penalized values.
// A vocab id can appear multiple times but should be penalized once.
for (int index = threadIdx.x; index < step; index += blockDim.x) {
// Skip the padding tokens in input sequences.
if (index >= input_length && index < max_input_length) {
continue;
}
// output_ids shape: (input_len + output_len, batch_size)
int penalty_index = output_ids[index * batch_size + batch_idx];
assert(penalty_index < vocab_size);
penalty_indices[index] = penalty_index;
float logit = (float)logits[penalty_index];
if (penalty_type == RepetitionPenaltyType::Additive) {
penalty_logits[index] = logit - penalty;
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
penalty_logits[index] = logit < 0.0f ? logit * penalty : logit / penalty;
}
else if (penalty_type == RepetitionPenaltyType::None) {
penalty_logits[index] = logit;
}
else {
// Unsupported type
assert(false);
}
}
if (blockDim.x > 32) {
__syncthreads();
}
// Phase 2. Replace a logit value by the penalized one.
for (int index = threadIdx.x; index < step; index += blockDim.x) {
// Skip the padding tokens in input sequences.
if (index >= input_length && index < max_input_length) {
continue;
}
logits[penalty_indices[index]] = penalty_logits[index];
}
}
template<typename T>
void invokeBatchApplyRepetitionPenalty(T* logits,
const float* penalties,
const int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* input_lengths,
const int max_input_length,
const int step,
RepetitionPenaltyType penalty_type,
cudaStream_t stream)
{
// Inputs
// logits [local_batch_size, vocab_size] : logit values.
// penalties [local_batch_size] : repetition penalty factors.
// output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size).
// input_lengths [local_batch_size], input lengths (optional).
// Padding tokens at [input_length, max_input_length) of input will not be penalized.
dim3 block(min(step, 1024));
dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));
if (penalty_type == RepetitionPenaltyType::Additive) {
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
else if (penalty_type == RepetitionPenaltyType::None) {
// do nothing
}
}
template void invokeBatchApplyRepetitionPenalty(float* logits,
const float* penalties,
const int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* input_lengths,
const int max_input_length,
const int step,
RepetitionPenaltyType penalty_type,
cudaStream_t stream);
template void invokeBatchApplyRepetitionPenalty(half* logits,
const float* penalties,
const int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* input_lengths,
const int max_input_length,
const int step,
RepetitionPenaltyType penalty_type,
cudaStream_t stream);
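// Min-length penalty: while the number of generated tokens is still below min_lengths[bid],
// force the end-of-sequence logit to -MAX so that the end token cannot be selected yet.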
template<typename T>
__global__ void batchApplyMinLengthPenalty(T* logits,
const int* min_lengths,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int vocab_size_padded)
{
int bid = threadIdx.x + blockIdx.x * blockDim.x; // batch index
// We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
// which is equal to the length of k/v caches.
if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
}
}
template<typename T>
void invokeMinLengthPenalty(T* logits,
const int* min_lengths,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream)
{
const int block_size = min(batch_size, 1024);
const int grid_size = (batch_size + block_size - 1) / block_size;
batchApplyMinLengthPenalty<<<grid_size, block_size, 0, stream>>>(
logits, min_lengths, end_ids, sequence_lengths, max_input_length, vocab_size_padded);
}
template void invokeMinLengthPenalty(float* logits,
const int* min_lengths,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
template void invokeMinLengthPenalty(half* logits,
const int* min_lengths,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include "src/fastertransformer/kernels/penalty_types.h"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
template<typename T>
void invokeApplyRepetitionPenalty(T* logits,
const float penalty,
const int* start_ids,
int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int vocab_size_padd,
const int* input_lengths,
const int max_input_len,
const int step,
const RepetitionPenaltyType penalty_type,
cudaStream_t stream);
template<typename T>
void invokeBatchApplyRepetitionPenalty(T* logits,
const float* penalties,
const int* output_ids,
const int batch_size,
const int local_batch_size,
const int vocab_size,
const int* input_lengths,
const int max_input_length,
const int step,
const RepetitionPenaltyType penalty_type,
cudaStream_t stream);
template<typename T>
void invokeApplyTemperaturePenalty(T* logits,
const T* bias,
const float temperature,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
template<typename T>
void invokeBatchApplyTemperaturePenalty(T* logits,
const T* bias,
const float* temperatures,
const int batch_size,
const int vocab_size,
const int vocab_size_padd,
cudaStream_t stream);
template<typename T>
void invokeMinLengthPenalty(T* logits,
const int* min_lengths,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdexcept>
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/kernels/sampling_topk_kernels.h"
namespace fastertransformer {
__global__ void curandInitialize(curandState_t* state, const int size, const unsigned long long random_seed)
{
if (threadIdx.x + blockIdx.x * blockDim.x < size) {
curand_init(random_seed, 0, 0, &state[blockIdx.x * blockDim.x + threadIdx.x]);
}
}
void invokeCurandInitialize(curandState_t* state,
const size_t batch_size,
const unsigned long long random_seed,
cudaStream_t stream)
{
dim3 block(256);
dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
curandInitialize<<<grid, block, 0, stream>>>(state, batch_size, random_seed);
}
__global__ void curandBatchInitialize(curandState_t* states, const int size, const unsigned long long* random_seeds)
{
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < size) {
curand_init(random_seeds[idx], 0, 0, &states[idx]);
}
}
void invokeCurandBatchInitialize(curandState_t* states,
const size_t batch_size,
const unsigned long long* random_seeds,
cudaStream_t stream)
{
dim3 block(256);
dim3 grid((int)(ceil(batch_size * 1.0 / 256)));
curandBatchInitialize<<<grid, block, 0, stream>>>(states, batch_size, random_seeds);
}
template<typename T>
__global__ void addBiasEndMask(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int vocab_size,
const int vocab_size_padded)
{
int bid = blockIdx.x;
bool finish = finished != nullptr ? finished[bid] : false;
int offset = bid * vocab_size_padded;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
for (int tid = threadIdx.x; tid < vocab_size_padded; tid += blockDim.x) {
if (tid >= vocab_size) {
logits[offset + tid] = -MAX_T_VAL;
}
else if (finish) {
logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL;
}
else {
if (bias != nullptr) {
logits[offset + tid] += bias[tid];
}
}
}
}
template<typename T>
void invokeAddBiasEndMask(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int batch_size,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream)
{
dim3 grid(batch_size);
dim3 block(min(vocab_size_padded, 1024));
/* vocab_size is usually very large, e.g., 30000 or 7000, so each block strides over the whole (padded) vocabulary. */
addBiasEndMask<<<grid, block, 0, stream>>>(logits, bias, end_ids, finished, vocab_size, vocab_size_padded);
}
template void invokeAddBiasEndMask(float* logits,
const float* bias,
const int* end_ids,
const bool* finished,
const int batch_size,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
template void invokeAddBiasEndMask(half* logits,
const half* bias,
const int* end_ids,
const bool* finished,
const int batch_size,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
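// Two-stage batched top-k sampling:
//   stage 1 - BLOCKS_PER_BEAM_ blocks per batch row each extract their local top-k
//             (k iterations of a block-wide arg-max) into topk_tmp_id_buf / topk_tmp_val_buf;
//   stage 2 - one block per batch row merges the k * BLOCKS_PER_BEAM_ candidates, builds a
//             softmax over them, and draws a token with curand (the sampling mass can be
//             truncated further by top_p).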
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage1(const T* __restrict log_probs,
T* tmp_log_probs,
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const bool* finished,
const int max_top_k,
const int* top_ks,
const int vocab_size,
const int* end_ids,
const bool* skip_decode)
{
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int batch_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k; // batch_id = batch index
const int tmp_log_buf_index = batch_id * vocab_size;
const int tmp_topk_buf_index = batch_id * BLOCKS_PER_BEAM_ * max_top_k + block_lane * k;
TopK_2<T> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
if (finished != nullptr && finished[batch_id] == true) {
if (tid < k) {
const int index = tmp_topk_buf_index + tid;
if (block_lane == 0 && tid == 0) {
const int end_id = end_ids[batch_id];
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
}
else {
topk_tmp_id_buf[index] = -1;
topk_tmp_val_buf[index] = -MAX_T_VAL;
}
}
return;
}
for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
int index = elem_id + tmp_log_buf_index;
tmp_log_probs[index] = log_probs[index];
}
for (int ite = 0; ite < k; ite++) {
partial.init();
#pragma unroll
for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
int index = elem_id + tmp_log_buf_index;
partial.insert(tmp_log_probs[index], index);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
if (tid == 0) {
const int index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p;
topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL;
}
__syncthreads();
}
}
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage2_sampling(const int* __restrict topk_tmp_id_buf,
T* topk_tmp_val_buf,
int* ids,
int* sequence_length,
bool* finished,
float* cum_log_probs,
float* output_log_probs,
const int max_top_k,
const int* top_ks,
const float top_p,
const float* top_ps,
curandState_t* curandstate,
const int* end_ids,
const int vocab_size,
const bool* skip_decode)
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
const int k = (top_ks != nullptr) ? top_ks[batch_id] : max_top_k;
const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
const int size = k * BLOCKS_PER_BEAM_;
const int stride = max_top_k * BLOCKS_PER_BEAM_;
typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
extern __shared__ char array[];
__shared__ float rand_num;
__shared__ float s_sum;
__shared__ float s_max;
T* s_val = topk_tmp_val_buf + batch_id * stride;
int* s_id = reinterpret_cast<int*>(array);
if (tid == 0) {
s_sum = 0.0f;
}
TopK_2<float> partial;
if (finished != nullptr && finished[batch_id] == true) {
ids[batch_id] = end_ids[batch_id];
return;
}
float* s_val2 = reinterpret_cast<float*>(s_id + k);
for (int ite = 0; ite < k; ite++) {
partial.init();
#pragma unroll
for (int i = tid; i < size; i += BLOCK_SIZE_) {
partial.insert((float)s_val[i], i);
}
TopK_2<float> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<float>);
if (tid == 0) {
if (ite == 0) {
s_max = total.u;
}
s_id[ite] = total.p;
s_val[total.p] = -MAX_T_VAL;
// when cum_log_probs or output_log_probs is computed, topk_tmp_val_buf (logits_buf_) has
// already been pre-processed by softmax_kernel
if (cum_log_probs == nullptr && output_log_probs == nullptr) {
total.u = __expf(total.u - s_max);
}
s_val2[ite] = total.u;
s_sum += total.u;
}
__syncthreads();
}
if (tid == 0) {
rand_num = (float)curand_uniform(curandstate + blockIdx.x) * prob_threshold * s_sum;
for (int i = 0; i < k; i++) {
float exp_logit = s_val2[i];
rand_num = rand_num - exp_logit;
if (rand_num <= 0.0f || i == k - 1) {
ids[batch_id] = topk_tmp_id_buf[batch_id * stride + s_id[i]] % vocab_size;
if (cum_log_probs != nullptr || output_log_probs != nullptr) {
float log_prob = logf(exp_logit);
if (cum_log_probs != nullptr) {
cum_log_probs[batch_id] += log_prob;
}
if (output_log_probs != nullptr) {
// 'output_log_probs' is the probability induced by the top-k sampling.
// We normalize the probability 'exp_logit' of the selected token by
// the probability 's_sum' of a set of top-k tokens, meaning the log_prob
// is the probability of the selected token, conditioned on the event that
// it is selected, i.e.,
// log_prob = log P(i | i is in top-k) = log(exp_logit / s_sum).
output_log_probs[batch_id] = log_prob - logf(s_sum);
}
}
break;
}
}
if (sequence_length != nullptr && finished != nullptr) {
sequence_length[batch_id] = finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished[batch_id] = ids[batch_id] == end_ids[batch_id] ? true : false;
}
}
}
#define CASE_K(K_MIN, K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \
case K_MIN ... K_MAX: \
topk_stage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_> \
<<<batch_size * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(log_probs, \
temp_log_probs, \
topk_tmp_id_buf, \
topk_tmp_val_buf, \
finished, \
max_top_k, \
top_ks, \
vocab_size, \
end_ids, \
skip_decode); \
topk_stage2_sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
<<<batch_size, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topk_tmp_id_buf, \
topk_tmp_val_buf, \
ids, \
sequence_length, \
finished, \
cum_log_probs, \
output_log_probs, \
max_top_k, \
top_ks, \
top_p, \
top_ps, \
curandstate, \
end_ids, \
vocab_size, \
skip_decode); \
break;
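// Typical call pattern (sketch): when workspace == nullptr, the function below only reports the
// required workspace_size and returns, so callers query the size first, allocate it on the
// device, and then call again with the real buffer, e.g.
//   size_t ws_size = 0;
//   invokeBatchTopKSampling<float>(nullptr, ws_size, ... /* remaining arguments */);
//   // allocate ws_size bytes on the device, then repeat the call with the allocated workspace.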
template<typename T>
void invokeBatchTopKSampling(void* workspace,
size_t& workspace_size,
const T* log_probs,
int* ids,
int* sequence_length,
bool* finished,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int max_top_k,
const int* top_ks,
const float top_p,
const float* top_ps,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode)
{
// Do not allow ambiguous inputs: top_p and top_ps must not be set at the same time.
assert(top_p == 1.0f || top_ps == nullptr);
const int vocab_size = vocab_size_padded;
const int max_block_per_beam = 8;
int temp_log_probs_buf_size = batch_size * vocab_size; // type float
int topk_tmp_ids_buf_size = batch_size * max_top_k * max_block_per_beam; // type int
int topk_tmp_val_buf_size = batch_size * max_top_k * max_block_per_beam; // type float
// round buffer sizes up to multiples of 4 elements to avoid misaligned addresses
temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
if (workspace == nullptr) {
workspace_size = sizeof(T) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
+ sizeof(T) * topk_tmp_val_buf_size;
return;
}
T* temp_log_probs = (T*)workspace;
int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size);
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
switch (max_top_k) {
CASE_K(1, 16, 128, 128, 8);
CASE_K(17, 32, 256, 128, 8);
CASE_K(33, 64, 256, 256, 8);
CASE_K(65, 1024, 256, 256, 8);
default:
throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", max_top_k));
}
}
#undef CASE_K
template void invokeBatchTopKSampling(void* workspace,
size_t& workspace_size,
const float* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int max_top_k,
const int* top_ks,
const float top_p,
const float* top_ps,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
template void invokeBatchTopKSampling(void* workspace,
size_t& workspace_size,
const half* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int max_top_k,
const int* top_ks,
const float top_p,
const float* top_ps,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
template<typename T>
void invokeTopKSampling(void* workspace,
size_t& workspace_size,
const T* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode)
{
invokeBatchTopKSampling(workspace,
workspace_size,
log_probs,
ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
curandstate,
top_k,
nullptr,
top_p,
nullptr,
vocab_size_padded,
end_ids,
stream,
batch_size,
skip_decode);
}
template void invokeTopKSampling(void* workspace,
size_t& workspace_size,
const float* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
template void invokeTopKSampling(void* workspace,
size_t& workspace_size,
const half* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
template<typename T>
void invokeTopKTopPSampling(void* workspace,
size_t& workspace_size,
int* output_ids,
const T* logits,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int batch_size,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream)
{
// invokeTopKTopPSampling will be deprecated. Please use invokeTopKSampling instead.
invokeTopKSampling(workspace,
workspace_size,
logits,
output_ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
curandstate,
top_k,
top_p,
vocab_size_padded,
end_ids,
stream,
batch_size,
nullptr);
}
template void invokeTopKTopPSampling(void* workspace,
size_t& workspace_size,
int* output_ids,
const float* logits,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int batch_size,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream);
template void invokeTopKTopPSampling(void* workspace,
size_t& workspace_size,
int* output_ids,
const half* logits,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int batch_size,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/utils/logger.h"
#include <curand_kernel.h>
namespace fastertransformer {
template<typename T>
void invokeTopKSampling(void* workspace,
size_t& workspace_size,
const T* log_probs,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
template<typename T>
void invokeBatchTopKSampling(void* workspace,
size_t& workspace_size,
const T* log_probs,
int* ids,
int* sequence_length,
bool* finished,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int max_top_k,
const int* top_ks,
const float top_p,
const float* top_ps,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream,
const int batch_size,
const bool* skip_decode);
void invokeCurandInitialize(curandState_t* state,
const size_t batch_size,
unsigned long long random_seed,
cudaStream_t stream);
void invokeCurandBatchInitialize(curandState_t* states,
const size_t batch_size,
const unsigned long long* random_seeds,
cudaStream_t stream);
template<typename T>
void invokeAddBiasEndMask(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int batch_size,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
template<typename T>
void invokeTopKTopPSampling(void* workspace,
size_t& workspace_size,
int* output_ids,
const T* logits,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
curandState_t* curandstate,
const int batch_size,
const int top_k,
const float top_p,
const int vocab_size_padded,
const int* end_ids,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/kernels/sampling_topp_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
constexpr int ENABLE_SINGLE_PASS_TOP_P = 0;
constexpr float SINGLE_PASS_THRESHOLD = 0.9;
namespace fastertransformer {
namespace segmented_topp_impl {
template<int HALF_ELEMENTS_PER_WARP_LOAD>
using Copy_half_t = typename std::conditional<
HALF_ELEMENTS_PER_WARP_LOAD == 32,
half,
typename std::conditional<HALF_ELEMENTS_PER_WARP_LOAD == 64,
int,
typename std::conditional<HALF_ELEMENTS_PER_WARP_LOAD == 128, int2, int4>::type>::type>::
type;
template<typename T, int ELEMENTS_PER_WARP_LOAD>
using Copy_t = Copy_half_t<sizeof(T) / sizeof(half) * ELEMENTS_PER_WARP_LOAD>;
template<typename T>
struct Float_as_int_ {
};
template<>
struct Float_as_int_<float> {
using Type = uint32_t;
};
template<>
struct Float_as_int_<__half> {
using Type = uint16_t;
};
using kernel_params_float = Segmented_topk_kernel_params<float, int32_t, 256, 2>;
using kernel_params_float_1 = Segmented_topk_kernel_params<float, int32_t, 256, 1>;
using kernel_params_half = Segmented_topk_kernel_params<__half, int32_t, 256, 4>;
using kernel_params_half_1 = Segmented_topk_kernel_params<__half, int32_t, 256, 1>;
///////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float to_float(uint32_t src)
{
return __int_as_float(src);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float to_float(uint16_t src)
{
__half dst = __ushort_as_half(src);
return __half2float(dst);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// sort one segment per cta
template<typename T_SCORE, int BLOCK_THREADS, int ELEMENTS_PER_THREAD>
__global__ void blockSortKernel(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items_,
int stride_items,
int num_segments)
{
// Specialize BlockRadixSort for a 1D block
typedef cub::BlockRadixSort<T_SCORE, BLOCK_THREADS, ELEMENTS_PER_THREAD, int32_t> BlockRadixSort;
// Allocate shared memory for BlockRadixSort
__shared__ typename BlockRadixSort::TempStorage temp_storage;
if (blockIdx.x >= num_segments) {
return;
}
int num_items = active_counts[blockIdx.x]; // > num_items_ ? num_items_ : active_counts[blockIdx.x];
if (num_items == 0) {
return;
}
// Obtain a segment of consecutive items that are blocked across threads
T_SCORE thread_keys[ELEMENTS_PER_THREAD];
int32_t thread_values[ELEMENTS_PER_THREAD];
int32_t block_offset = blockIdx.x * stride_items;
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_out + block_offset, thread_keys, num_items, 0);
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_values_out + block_offset, thread_values, num_items, -1);
__syncthreads();
// Collectively sort the keys and values among block threads
BlockRadixSort(temp_storage).SortDescendingBlockedToStriped(thread_keys, thread_values);
// Store output in striped fashion
cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_out + block_offset, thread_keys, num_items);
cub::StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_values_out + block_offset, thread_values, num_items);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// block sort kernel
template<typename T_SCORE>
void blockSort(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items,
int stride_items,
int num_segments,
cudaStream_t stream)
{
if (num_items == 0) {
return;
}
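    // Pick a kernel variant: for up to 1024 items we use a block of
    // ceil(num_items / 128) * 128 threads with one item per thread; beyond that we keep
    // 1024 threads and switch to the multi-item-per-thread variants at the end of the table.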
int kernel_index = div_up(num_items, 128) - 1;
int warps_per_cta = (kernel_index + 1) * 128 / 32;
if (kernel_index > 7) {
kernel_index = 7 + div_up(num_items, 1024) - 1;
warps_per_cta = 1024 / 32;
}
assert(warps_per_cta <= 32);
dim3 block(warps_per_cta * 32);
dim3 grid(num_segments);
using kernel_func = void (*)(const T_SCORE* d_keys_in,
T_SCORE* d_keys_out,
const int32_t* d_values_in,
int32_t* d_values_out,
const int32_t* active_counts,
int num_items,
int stride_items,
int num_segments);
static const kernel_func kernel_funcs[] = {
&blockSortKernel<T_SCORE, 128, 1>,
&blockSortKernel<T_SCORE, 256, 1>,
&blockSortKernel<T_SCORE, 384, 1>,
&blockSortKernel<T_SCORE, 512, 1>,
&blockSortKernel<T_SCORE, 640, 1>,
&blockSortKernel<T_SCORE, 768, 1>,
&blockSortKernel<T_SCORE, 896, 1>,
&blockSortKernel<T_SCORE, 1024, 1>,
&blockSortKernel<T_SCORE, 1024, 2>,
&blockSortKernel<T_SCORE, 1024, 4>,
//&blockSortKernel<T_SCORE, 1024, 6>,
};
kernel_funcs[kernel_index]<<<grid, block, 0, stream>>>(
d_keys_in, d_keys_out, d_values_in, d_values_out, active_counts, num_items, stride_items, num_segments);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
struct BlockPrefixCallbackOp {
// Running prefix
int running_total;
// Constructor
__device__ BlockPrefixCallbackOp(uint32_t running_total): running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ int operator()(uint32_t block_aggregate)
{
uint32_t old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
#define DO_DEBUG_PRINT 0
// governs the split between regs and smem
constexpr float SMEM_FRACTION = 0.5F;
constexpr float P_EPSILON = 0.01F;
constexpr int MAX_TOP_K = 3072;
constexpr int WARP_SZ = 32;
template<typename Kernel_params, int ITEMS_PER_THREAD>
__global__ __launch_bounds__(Kernel_params::BLOCK_THREADS,
1) void segmented_top_p_single_pass(TopKPerSegmentParams params)
{
#if DO_DEBUG_PRINT
constexpr int debug_block_id = 26;
#endif
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Int_Key_Data_Type = typename Float_as_int_<Key_Data_Type>::Type;
// 4 fp16 keys or 2 fp32 keys
constexpr int KEYS_PER_LDG = Kernel_params::KEYS_PER_LDG;
typedef Copy_t<Key_Data_Type, WARP_SZ * KEYS_PER_LDG> copy_t;
union access_t {
copy_t v;
Int_Key_Data_Type x[KEYS_PER_LDG]; // supported size 1,2,4
};
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
constexpr int ITEMS_PER_THREAD_IN_REGS = ITEMS_PER_THREAD * (1.0F - SMEM_FRACTION);
constexpr int ITEMS_PER_THREAD_IN_SMEM = ITEMS_PER_THREAD - ITEMS_PER_THREAD_IN_REGS;
#if DO_DEBUG_PRINT == 1
if (blockIdx.x == 0 && threadIdx.x == 0) {
printf("ITEMS_PER_THREAD, ITEMS_PER_THREAD_IN_REGS, ITEMS_PER_THREAD_IN_SMEM = %d, %d, %d\n",
ITEMS_PER_THREAD,
ITEMS_PER_THREAD_IN_REGS,
ITEMS_PER_THREAD_IN_SMEM);
}
#endif
constexpr int MIN_KEY = 0;
constexpr int ENABLED_PER_THREAD = (ITEMS_PER_THREAD + 32 - 1) / 32;
extern __shared__ int2 dynamic_smem[];
int2* smem_selected_elements = dynamic_smem;
Int_Key_Data_Type* smem_thread_items = reinterpret_cast<Int_Key_Data_Type*>(smem_selected_elements + MAX_TOP_K);
__shared__ unsigned int smem_selected_count;
// Specialize BlockScan type for our thread block
typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
// Specialize BlockScan type for our thread block
typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce;
__shared__ float smem_p_sum_total;
__shared__ union {
typename BlockScan::TempStorage scan;
typename BlockReduce::TempStorage reduce;
} temp_storage;
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
unsigned int old_selected_count;
uint32_t segment = blockIdx.y * gridDim.x + blockIdx.x;
// The preceding top-k pass has short-circuited this segment
if (params.gmem_begin_offsets[segment] == params.gmem_end_offsets[segment]) {
if (threadIdx.x == 0) {
params.gmem_active_count_per_segment[segment] = 1;
atomicMax(params.gmem_active_count_total, 1);
}
return;
}
Int_Key_Data_Type* gmem_src_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_src_keys);
Int_Key_Data_Type* gmem_dst_keys = reinterpret_cast<Int_Key_Data_Type*>(params.gmem_dst_keys);
int32_t* gmem_dst_vals = reinterpret_cast<int32_t*>(params.gmem_dst_vals);
constexpr int BITS_IN_KEY = sizeof(Key_Data_Type) * 8;
int items = params.num_items / params.num_segments;
int first_index = segment * items;
gmem_src_keys += first_index;
gmem_dst_keys += first_index;
gmem_dst_vals += first_index;
int index_limit = items;
Int_Key_Data_Type thread_items[ITEMS_PER_THREAD_IN_REGS] = {0};
// Load all keys into registers and smem
const int lane_id = threadIdx.x % WARP_SZ;
const int warp_id = threadIdx.x / WARP_SZ;
constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SZ;
access_t ZERO;
for (int i = 0; i < KEYS_PER_LDG; i++) {
ZERO.x[i] = MIN_KEY;
}
// registers
for (int iter = 0; iter < ITEMS_PER_THREAD_IN_REGS; iter++) {
int offset = (iter + threadIdx.x * ITEMS_PER_THREAD);
thread_items[iter] = (offset < index_limit) ? gmem_src_keys[offset] : MIN_KEY;
}
// shared memory
for (int c = warp_id; c < BLOCK_THREADS; c += NUM_WARPS) {
for (int iter = lane_id * KEYS_PER_LDG; iter < ITEMS_PER_THREAD_IN_SMEM; iter += WARP_SZ * KEYS_PER_LDG) {
int offset = iter + c * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS;
access_t val;
val.v = (offset < index_limit) ? *reinterpret_cast<copy_t*>(&gmem_src_keys[offset]) : ZERO.v;
for (int i = 0; i < KEYS_PER_LDG; i++) {
smem_thread_items[c + (iter + i) * BLOCK_THREADS] = val.x[i];
}
// smem_thread_items[c + iter * BLOCK_THREADS] = (offset < index_limit)? gmem_src_keys[offset] : MIN_KEY;
}
}
Int_Key_Data_Type select_mask = 0;
Int_Key_Data_Type save_mask = 0;
// Int_Key_Data_Type save_bit = 0;
// set to true when we finish with too few keys, so we go back to the last saved save_mask one more time
bool is_last_iter = false;
if (threadIdx.x == 0) {
smem_selected_count = 0;
old_selected_count = 0;
}
// Iterate over the key bits from high to low, skipping the top two bits:
// * the sign bit: all values are positive
// * the next bit, which is only set for values >= 2, while the inputs lie in
//   the range [0, 1]
constexpr int START_BIT = BITS_IN_KEY - 1;
constexpr int SKIP_BITS = 2;
constexpr Int_Key_Data_Type ONE = (Int_Key_Data_Type)1;
uint32_t selected;
uint32_t sc;
float p_sum_total = 0.0F;
float old_p_sum_total = 0.0F;
uint32_t offset = 0;
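    // Radix-select on the bit pattern of the probabilities (all in [0, 1], so a larger bit
    // pattern means a larger value). For each bit, from high to low, we tentatively require
    // it in the prefix mask and measure the probability mass of the matching keys:
    //  - if the accumulated mass still falls short of top_p, the matching keys are committed
    //    to the selection and the bit is dropped from the requirement;
    //  - if the mass overshoots top_p by more than P_EPSILON and the selection would exceed
    //    MAX_TOP_K keys, the bit is kept in select_mask (restricting the selection to larger
    //    values) and the running totals are rolled back to the last saved state.
    // The loop stops once the mass is within P_EPSILON of top_p, the mass exceeds top_p with
    // at most MAX_TOP_K selected keys, or the last bit has been processed.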
for (Int_Key_Data_Type bit = START_BIT - SKIP_BITS; true; --bit) {
__syncthreads();
Int_Key_Data_Type bit_mask = select_mask | (ONE << bit);
uint32_t enabled[ENABLED_PER_THREAD] = {0};
float thread_sum = 0.0F;
for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
// check whether all the bits of bit_mask are set in the thread item; if so, set the
// corresponding bit of 'enabled'
auto val = thread_items[item];
uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
// thread_sum += (is_enabled)? to_float(val) : 0.0F;
thread_sum += is_enabled * to_float(val);
enabled[item / 32] |= is_enabled << (item % 32);
}
for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
int idx = threadIdx.x + item * BLOCK_THREADS;
// int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
auto val = smem_thread_items[idx];
uint32_t is_enabled = uint32_t(((val ^ bit_mask) & bit_mask) == 0);
// thread_sum += (is_enabled)? to_float(val) : 0.0F;
thread_sum += is_enabled * to_float(val);
enabled[(ITEMS_PER_THREAD_IN_REGS + item) / 32] |= is_enabled << ((ITEMS_PER_THREAD_IN_REGS + item) % 32);
}
selected = 0;
#pragma unroll
for (int i = 0; i < ENABLED_PER_THREAD; i++) {
selected += __popc(enabled[i]);
}
float p_sum = BlockReduce(temp_storage.reduce).Sum(thread_sum);
if (threadIdx.x == 0) {
p_sum_total += p_sum;
smem_p_sum_total = p_sum_total;
}
__syncthreads();
p_sum_total = smem_p_sum_total;
__syncthreads();
BlockScan(temp_storage.scan).ExclusiveSum(selected, offset, prefix_op);
if (threadIdx.x == 0) {
smem_selected_count = prefix_op.running_total;
}
__syncthreads();
sc = smem_selected_count;
__syncthreads();
// float p_diff = params.top_p - p_sum_total;
float p_diff = p_sum_total - params.top_p;
if ((p_sum_total <= params.top_p + P_EPSILON && p_sum_total > 0)
|| (p_sum_total > params.top_p && sc <= MAX_TOP_K) || (bit == 0 && p_sum_total > 0) || is_last_iter) {
#if DO_DEBUG_PRINT == 1
__syncthreads();
if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
sc = smem_selected_count;
printf("bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
bit,
bit_mask,
offset,
blockIdx.x,
threadIdx.x,
sc,
p_sum,
p_sum_total);
}
__syncthreads();
#endif
for (int item = 0; item < ITEMS_PER_THREAD_IN_REGS; ++item) {
// the offset < MAX_TOP_K condition should not trigger with well-trained weights, but
// without it we would get an illegal memory access in those rare cases
if (enabled[item / 32] & (ONE << (item % 32)) && offset < MAX_TOP_K) {
smem_selected_elements[offset] =
make_int2(thread_items[item], item + threadIdx.x * ITEMS_PER_THREAD);
++offset;
thread_items[item] = MIN_KEY;
}
}
for (int item = 0; item < ITEMS_PER_THREAD_IN_SMEM; ++item) {
if (enabled[(item + ITEMS_PER_THREAD_IN_REGS) / 32] & (ONE << ((item + ITEMS_PER_THREAD_IN_REGS) % 32))
&& offset < MAX_TOP_K) {
int idx = threadIdx.x + item * BLOCK_THREADS;
// int idx = item + ITEMS_PER_THREAD_IN_SMEM * threadIdx.x;
// if (idx < params.num_items_per_segment_in_smem)
{
smem_selected_elements[offset] = make_int2(
smem_thread_items[idx], item + threadIdx.x * ITEMS_PER_THREAD + ITEMS_PER_THREAD_IN_REGS);
++offset;
smem_thread_items[idx] = MIN_KEY;
}
}
}
}
#if DO_DEBUG_PRINT == 1
if (threadIdx.x == 0 && blockIdx.x == debug_block_id) {
printf("!!!! bit %d bit_mask %d offset %d (%d, %d), sc = %d, p_sum = %f, p_sum_total = %f\n",
bit,
bit_mask,
offset,
blockIdx.x,
threadIdx.x,
sc,
p_sum,
p_sum_total);
}
#endif
if (p_diff <= P_EPSILON && p_diff >= 0 || (p_sum_total > params.top_p && sc <= MAX_TOP_K) || bit == 0) {
break;
}
// p > top_p
else if (p_diff > P_EPSILON) {
// There are too many bits in the current selection
// Save the current state and go to the next bit
// If there are not enough items left using the next bit
// it's necessary to restart here with the current bit not set
save_mask = bit_mask;
select_mask |= bit_mask;
if (threadIdx.x == 0) {
smem_selected_count = old_selected_count;
p_sum_total = old_p_sum_total;
prefix_op.running_total = old_selected_count;
}
}
else {
// sc < num_top_k branch
if (save_mask) {
select_mask = save_mask;
save_mask = 0;
}
if (threadIdx.x == 0) {
old_selected_count = smem_selected_count;
old_p_sum_total = p_sum_total;
}
}
}
__syncthreads();
// store data to global memory
sc = (p_sum_total < params.top_p) ? params.num_items / params.num_segments : smem_selected_count;
if (threadIdx.x == 0) {
params.gmem_active_count_per_segment[segment] = sc;
atomicMax(params.gmem_active_count_total, sc);
}
if (sc >= MAX_TOP_K) {
return;
}
for (int i = threadIdx.x; i < sc; i += blockDim.x) {
int2 selected_element = smem_selected_elements[i];
gmem_dst_keys[i] = selected_element.x;
gmem_dst_vals[i] = selected_element.y;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
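// Computes the dynamic shared memory required by segmented_top_p_single_pass for this
// problem size, or returns -1 if the configuration is unsupported (too much shared memory,
// top_p too close to 1, or sizes not divisible by the vectorized load width); a negative
// return makes the caller fall back to the radix-sort path.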
template<typename Kernel_params>
int getSmemSizeAndCheck(const TopKPerSegmentContext& context, const TopKPerSegmentParams& params)
{
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
int num_items_per_segment = params.num_items / params.num_segments;
constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;
int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
int smem_size = MAX_TOP_K * sizeof(int2);
const int items_per_thread = (kernel_index + 1) * ITEMS_INCREMENT;
const int items_per_thread_in_regs = items_per_thread * (1.0F - SMEM_FRACTION);
const int items_per_thread_in_smem = items_per_thread - items_per_thread_in_regs;
smem_size += items_per_thread_in_smem * BLOCK_THREADS * sizeof(typename Float_as_int_<Key_Data_Type>::Type);
int keys_per_ldg = 2 * sizeof(Key_Data_Type) / 2;
if (smem_size + BLOCK_THREADS * sizeof(float) > (size_t)context.sm_shared_size || // dynamic + static memory
items_per_thread_in_regs + items_per_thread_in_smem != items_per_thread || params.top_p + P_EPSILON > 1.0F
|| items_per_thread_in_regs % keys_per_ldg != 0 || items_per_thread_in_smem % keys_per_ldg != 0
|| num_items_per_segment % keys_per_ldg != 0) {
return -1;
}
return smem_size;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int getSmemSizeAndCheck(const TopKPerSegmentContext& context,
const TopKPerSegmentParams& params,
const DType_t DT_SCORE)
{
int num_items_per_segment = params.num_items / params.num_segments;
if (DT_SCORE == kFLOAT) {
if (num_items_per_segment % 2 == 0) {
return getSmemSizeAndCheck<kernel_params_float>(context, params);
}
else {
return getSmemSizeAndCheck<kernel_params_float_1>(context, params);
}
}
else {
if (num_items_per_segment % 4 == 0) {
return getSmemSizeAndCheck<kernel_params_half>(context, params);
}
else {
return getSmemSizeAndCheck<kernel_params_half_1>(context, params);
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
void segmentedTopPSinglePass_dispatch(const TopKPerSegmentParams& params,
const TopKPerSegmentContext& context,
cudaStream_t stream)
{
constexpr int BLOCK_THREADS = Kernel_params::BLOCK_THREADS;
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Value_Data_Type = typename Kernel_params::Value_Data_Type;
int num_items_per_segment = params.num_items / params.num_segments;
constexpr int ITEMS_INCREMENT = Kernel_params::ITEMS_INCREMENT;
int kernel_index = div_up(num_items_per_segment, BLOCK_THREADS * ITEMS_INCREMENT) - 1;
#define KERNEL_RUN(INDEX) \
{ \
if (smem_size > 0) \
check_cuda_error( \
cudaFuncSetAttribute(segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
smem_size)); \
segmented_top_p_single_pass<Kernel_params, ITEMS_INCREMENT*(INDEX + 1)> \
<<<grid_dim, Kernel_params::BLOCK_THREADS, smem_size, stream>>>(params); \
}
int smem_size = getSmemSizeAndCheck<Kernel_params>(context, params);
dim3 grid_dim(params.num_segments, 1);
switch (kernel_index) {
case 0:
KERNEL_RUN(0) break;
case 1:
KERNEL_RUN(1) break;
case 2:
KERNEL_RUN(2) break;
case 3:
KERNEL_RUN(3) break;
case 4:
KERNEL_RUN(4) break;
case 5:
KERNEL_RUN(5) break;
case 6:
KERNEL_RUN(6) break;
case 7:
KERNEL_RUN(7) break;
default:
exit(1);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
void topPPerSegment_dispatch(const TopKPerSegmentContext& context,
TopKPerSegmentParams& params,
void* temp_storage,
size_t& temp_storage_bytes,
cudaStream_t stream)
{
using Key_Data_Type = typename Kernel_params::Key_Data_Type;
using Value_Data_Type = typename Kernel_params::Value_Data_Type;
if (temp_storage == nullptr) {
if (params.num_segments > 1) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(temp_storage,
temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
params.num_segments,
params.gmem_begin_offsets,
params.gmem_end_offsets,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
else {
cub::DeviceRadixSort::SortPairsDescending(temp_storage,
temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
temp_storage_bytes = div_up(temp_storage_bytes, 256) * 256;
// total active counts
temp_storage_bytes += div_up(sizeof(int), 256) * 256;
// storage for gmem_end_offsets
temp_storage_bytes += div_up(sizeof(int) * params.num_segments, 256) * 256;
return;
}
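    // Carve the caller-provided buffer into [cub radix-sort temp storage | total active
    // count (one int, 256-byte aligned) | per-segment active counts], mirroring the size
    // computed in the query pass above.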
size_t cub_temp_storage_bytes =
temp_storage_bytes - div_up(sizeof(int), 256) * 256 - div_up(sizeof(int) * params.num_segments, 256) * 256;
void* cub_temp_storage = temp_storage;
params.gmem_active_count_total = reinterpret_cast<int*>((char*)temp_storage + cub_temp_storage_bytes);
params.gmem_active_count_per_segment =
reinterpret_cast<int*>((char*)params.gmem_active_count_total + div_up(sizeof(int), 256) * 256);
int num_items_per_segment = params.num_items / params.num_segments;
cudaMemsetAsync(params.gmem_active_count_total, 0, sizeof(int), stream);
cudaMemsetAsync(params.gmem_dst_keys, 0, params.num_items * sizeof(Key_Data_Type), stream);
segmentedTopPSinglePass_dispatch<Kernel_params>(params, context, stream);
int max_num_items = 0;
cudaMemcpyAsync(&max_num_items, params.gmem_active_count_total, sizeof(int), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
if (max_num_items >= MAX_TOP_K || max_num_items == 0) {
if (params.num_segments > 1) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
params.num_segments,
params.gmem_begin_offsets,
params.gmem_end_offsets,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
else {
cub::DeviceRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_bytes,
reinterpret_cast<Key_Data_Type*>(params.gmem_src_keys),
reinterpret_cast<Key_Data_Type*>(params.gmem_dst_keys),
reinterpret_cast<Value_Data_Type*>(params.gmem_src_vals),
reinterpret_cast<Value_Data_Type*>(params.gmem_dst_vals),
params.num_items,
0,
sizeof(Key_Data_Type) * 8,
stream);
}
}
else {
// sort only the selected candidates, sizing the launch for the largest per-segment active count
blockSort<Key_Data_Type>((const Key_Data_Type*)(params.gmem_dst_keys),
(Key_Data_Type*)(params.gmem_dst_keys),
(const Value_Data_Type*)(params.gmem_dst_vals),
(Value_Data_Type*)(params.gmem_dst_vals),
params.gmem_active_count_per_segment,
max_num_items,
num_items_per_segment,
params.num_segments,
stream);
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
int topPPerSegment(const TopKPerSegmentContext& context,
TopKPerSegmentParams& params,
const DType_t DT_SCORE,
void* temp_storage,
size_t& temp_storage_bytes,
cudaStream_t stream)
{
int num_items_per_segment = params.num_items / params.num_segments;
if (DT_SCORE == kFLOAT) {
if (num_items_per_segment % 2 == 0) {
topPPerSegment_dispatch<kernel_params_float>(context, params, temp_storage, temp_storage_bytes, stream);
}
else {
topPPerSegment_dispatch<kernel_params_float_1>(context, params, temp_storage, temp_storage_bytes, stream);
}
}
else {
if (num_items_per_segment % 4 == 0) {
topPPerSegment_dispatch<kernel_params_half>(context, params, temp_storage, temp_storage_bytes, stream);
}
else {
topPPerSegment_dispatch<kernel_params_half_1>(context, params, temp_storage, temp_storage_bytes, stream);
}
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace segmented_topp_impl
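// topPInitialize fills the index and offset buffers consumed by the segmented sort:
// topp_offset_buf / begin_topp_offset_buf_ hold the segment boundaries (i * n for row i)
// and topp_id_val_buf holds the vocabulary index (column id) of every logit.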
__global__ void topPInitialize(
int* topp_id_val_buf, int* topp_offset_buf, int* begin_topp_offset_buf_, const int batch_size, const int n)
{
int tid = threadIdx.x;
int bid = blockIdx.x;
if (bid == 0) {
for (int i = tid; i < batch_size + 1; i += blockDim.x) {
topp_offset_buf[i] = i * n;
begin_topp_offset_buf_[i] = topp_offset_buf[i];
}
}
int index = tid + bid * blockDim.x;
while (index < batch_size * n) {
topp_id_val_buf[index] = index % n;
index += blockDim.x * gridDim.x;
}
}
void invokeTopPInitialize(int* topp_id_val_buf,
int* topp_offset_buf,
int* begin_topp_offset_buf_,
const size_t batch_size,
const int n,
cudaStream_t stream)
{
// n: the number of columns of the logits buffer used for top_p sampling
topPInitialize<<<32, 512, 0, stream>>>(topp_id_val_buf, topp_offset_buf, begin_topp_offset_buf_, batch_size, n);
}
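// Shortcut kernel: computes the per-row top-MAX_K probabilities; if their sum already
// reaches the top-p threshold, the winners are written to the tmp buffers and
// begin_offset_buf is advanced past the row, so the segmented sort (and the full scan in
// topp_sampling) can skip that segment.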
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void topp_beam_topk_kernel(const T* log_probs, // prob.
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const int vocab_size,
int* offset_buf,
int* begin_offset_buf,
const float top_p,
const float* top_ps,
const bool* skip_decode)
{
int thread_id = threadIdx.x;
int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
float p_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -MAX_T_VAL;
}
#pragma unroll
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
int index = elem_id + batch_id * vocab_size;
partial.insert(log_probs[index], index);
}
TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);
if (thread_id == 0) {
begin_offset_buf[batch_id] = offset_buf[batch_id];
T sum_prob = (T)(0.0f);
#pragma unroll
for (int i = 0; i < MAX_K; i++) {
sum_prob += total.u[i];
}
if ((float)sum_prob >= p_threshold) {
begin_offset_buf[batch_id] += vocab_size;
int index = batch_id * vocab_size;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
topk_tmp_id_buf[index + i] = total.p[i] % vocab_size;
topk_tmp_val_buf[index + i] = total.u[i];
}
}
}
}
struct BlockPrefixCallbackOp {
// Running prefix
float running_total;
// Constructor
__device__ BlockPrefixCallbackOp(float running_total): running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ float operator()(float block_aggregate)
{
float old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
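// topp_sampling draws rand_num uniformly in (0, top_p] and walks the descending-sorted
// probabilities with a block-wide inclusive prefix sum until the running total first
// reaches rand_num; the token at that position is emitted. Rows short-circuited by
// topp_beam_topk_kernel simply take the stored top-1 candidate.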
template<typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T* sorted_log_probs,
int* sorted_id_vals,
int* ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const int* begin_offset_buf,
const int* offset_buf,
const int vocab_size,
curandState_t* curandstate,
const float top_p,
const float* top_ps,
const int* end_ids,
const int batch_size,
const bool* skip_decode)
{
__shared__ int stop_shared;
__shared__ float rand_num_s;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
if (skip_decode != nullptr && skip_decode[batch_id]) {
return;
}
constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.x / WARP_SIZE;
const float prob_threshold = (top_ps != nullptr) ? top_ps[batch_id] : top_p;
if (threadIdx.x == 0) {
stop_shared = 0;
rand_num_s = curand_uniform(curandstate + blockIdx.x) * prob_threshold;
}
// If begin_offset_buf and offset_buf of the sort have the same value, the best
// candidate was already found by the preceding topp_beam_topk_kernel and the sort
// was skipped for this segment, so we can skip the full scan here as well.
if (begin_offset_buf[batch_id] == offset_buf[batch_id]) {
if (tid == 0) {
int offset = batch_id * vocab_size;
ids[batch_id] = sorted_id_vals[offset];
if (cum_log_probs != nullptr || output_log_probs != nullptr) {
float lprob = logf(sorted_log_probs[offset]);
if (cum_log_probs != nullptr) {
cum_log_probs[batch_id] += lprob;
}
if (output_log_probs != nullptr) {
output_log_probs[batch_id] = lprob;
}
}
if (sequence_length != nullptr && finished_buf != nullptr) {
sequence_length[batch_id] =
finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0;
}
}
return;
}
typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ uint32_t selected_shared[NUM_WARPS];
// Initialize running total
BlockPrefixCallbackOp prefix_op(0);
if (lane_id == 0) {
selected_shared[warp_id] = 0;
}
__syncthreads();
int offset = batch_id * vocab_size;
ids[batch_id] = sorted_id_vals[offset];
int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int i_active = 0;
float thread_offset = 0;
for (int i = tid; i < end; i += BLOCK_SIZE) {
float thread_count = (i < vocab_size) ? (float)sorted_log_probs[offset + i] : 0.f;
BlockScan(temp_storage).InclusiveSum(thread_count, thread_offset, prefix_op);
uint32_t active_mask = __ballot_sync(0xFFFFFFFF, rand_num_s <= thread_offset);
i_active = i;
if (active_mask != 0) {
if (lane_id == 0) {
atomicAdd(&stop_shared, 1);
selected_shared[warp_id] = active_mask;
}
}
__syncthreads();
if (stop_shared > 0) {
break;
}
};
// select first active warp
bool skip = (selected_shared[warp_id] > 0) ? false : true;
for (int i = 0; i < warp_id; i++) {
if (selected_shared[i] != 0) {
skip = true;
}
}
if (!skip) {
int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
if (lane_id == active_lane_id) {
ids[batch_id] = sorted_id_vals[offset + i_active];
if (cum_log_probs != nullptr || output_log_probs != nullptr) {
float lprob = logf(sorted_log_probs[offset + i_active]);
if (cum_log_probs != nullptr) {
cum_log_probs[batch_id] += lprob;
}
if (output_log_probs != nullptr) {
output_log_probs[batch_id] = lprob;
}
}
if (sequence_length != nullptr && finished_buf != nullptr) {
sequence_length[batch_id] =
finished_buf[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
finished_buf[batch_id] = ids[batch_id] == end_ids[batch_id] ? 1 : 0;
}
}
}
}
template<typename T>
void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode)
{
// Batch size is passed as an argument because the batch sizes used at initialization
// and at inference may differ due to pipeline parallelism.
const int vocab_size = vocab_size_padded;
const int block_size = 256;
size_t sorted_log_prob_buf_size = batch_size * vocab_size * sizeof(T); // type T
size_t sorted_id_vals_buf_size = batch_size * vocab_size * sizeof(int); // type int
sorted_log_prob_buf_size = div_up(sorted_log_prob_buf_size, 256) * 256;
sorted_id_vals_buf_size = div_up(sorted_id_vals_buf_size, 256) * 256;
void* cub_temp_storage = workspace;
T* sorted_log_probs = (T*)((char*)cub_temp_storage + cub_temp_storage_size);
int* sorted_id_vals = (int*)((char*)sorted_log_probs + sorted_log_prob_buf_size);
bool do_radix_sort = (ENABLE_SINGLE_PASS_TOP_P == 0 || max_top_p >= SINGLE_PASS_THRESHOLD);
int smem_size = -1;
segmented_topp_impl::TopKPerSegmentContext context;
segmented_topp_impl::TopKPerSegmentParams params;
segmented_topp_impl::DType_t dataTypeKind =
(std::is_same<T, float>::value) ? segmented_topp_impl::kFLOAT : segmented_topp_impl::kHALF;
if (!do_radix_sort) {
FT_CHECK(cuda_device_prop != nullptr);
memset(&context, 0, sizeof(context));
context.sm_count = cuda_device_prop->multiProcessorCount;
context.sm_shared_size = cuda_device_prop->sharedMemPerMultiprocessor;
context.sm_version = cuda_device_prop->major * 100 + cuda_device_prop->minor * 10;
memset(&params, 0, sizeof(params));
params.gmem_src_keys = reinterpret_cast<void*>(const_cast<T*>(log_probs));
params.gmem_dst_keys = sorted_log_probs;
params.gmem_src_vals = reinterpret_cast<void*>(const_cast<int*>(id_vals));
params.gmem_dst_vals = reinterpret_cast<void*>(sorted_id_vals);
params.gmem_begin_offsets = begin_offset_buf;
params.gmem_end_offsets = offset_buf + 1;
params.workspace = nullptr;
params.num_items = vocab_size * batch_size;
params.num_segments = batch_size;
params.top_p = max_top_p;
params.confidence_threshold = 0.0F;
smem_size = getSmemSizeAndCheck(context, params, dataTypeKind);
do_radix_sort = smem_size < 0;
}
if (do_radix_sort) {
if (workspace == nullptr) {
check_cuda_error(
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
cub_temp_storage_size,
log_probs,
(T*)nullptr,
id_vals,
(int*)nullptr,
vocab_size * batch_size,
batch_size,
begin_offset_buf,
offset_buf + 1,
0, // begin_bit
sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8
stream)); // cudaStream_t
cub_temp_storage_size = div_up(cub_temp_storage_size, 256) * 256;
workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
return;
}
topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
sorted_id_vals,
sorted_log_probs,
vocab_size,
offset_buf,
begin_offset_buf,
max_top_p,
top_ps,
skip_decode);
check_cuda_error(
cub::DeviceSegmentedRadixSort::SortPairsDescending(cub_temp_storage,
cub_temp_storage_size,
log_probs,
sorted_log_probs,
id_vals,
sorted_id_vals,
vocab_size * batch_size,
batch_size,
begin_offset_buf,
offset_buf + 1,
0, // begin_bit
sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8
stream)); // cudaStream_t
}
else {
if (workspace == nullptr) {
segmented_topp_impl::topPPerSegment(
context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
workspace_size = sorted_log_prob_buf_size + sorted_id_vals_buf_size + cub_temp_storage_size;
return;
}
else {
topp_beam_topk_kernel<T, 1, block_size><<<batch_size, block_size, 0, stream>>>(log_probs,
sorted_id_vals,
sorted_log_probs,
vocab_size,
offset_buf,
begin_offset_buf,
max_top_p,
top_ps,
skip_decode);
segmented_topp_impl::topPPerSegment(
context, params, dataTypeKind, cub_temp_storage, cub_temp_storage_size, stream);
}
}
constexpr int SAMPLING_BLOCK_SIZE = 256;
dim3 grid(batch_size);
topp_sampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sorted_log_probs,
sorted_id_vals,
output_ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
begin_offset_buf,
offset_buf + 1,
vocab_size,
curandstate,
max_top_p,
top_ps,
end_ids,
batch_size,
skip_decode);
}
template void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const float* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const half* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template<typename T>
void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode)
{
invokeBatchTopPSampling(workspace,
workspace_size,
cub_temp_storage_size,
output_ids,
sequence_length,
finished_buf,
cum_log_probs,
output_log_probs,
log_probs,
id_vals,
offset_buf,
begin_offset_buf,
curandstate,
batch_size,
vocab_size_padded,
end_ids,
top_p,
nullptr,
stream,
cuda_device_prop,
skip_decode);
}
template void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const float* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const half* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
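// addBiasSoftMax adds the (optional) logit bias, masks the padded vocabulary tail to
// -MAX_T_VAL, forces finished rows to put all probability mass on their end_id, and then
// applies a numerically stable softmax (subtracting the row max) in place.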
template<typename T>
__global__ void
addBiasSoftMax(T* logits, const T* bias, const int* end_ids, const bool* finished, const int n_padded, const int n)
{
int bid = blockIdx.x;
bool finish = (finished != nullptr) ? finished[bid] : false;
int offset = bid * n_padded;
float max_val = -1 * FLT_MAX;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
__shared__ float s_max_val;
__shared__ float s_sum_val;
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
if (tid < n) {
if (finish) {
logits[offset + tid] = (tid == end_ids[bid]) ? MAX_T_VAL : -MAX_T_VAL;
}
else {
T bias_val = (bias != nullptr) ? bias[tid] : (T)0.0f;
logits[offset + tid] += bias_val;
}
}
else {
logits[offset + tid] = -MAX_T_VAL;
}
max_val = max(max_val, (float)logits[offset + tid]);
}
max_val = blockReduceMax<float>((float)max_val);
if (threadIdx.x == 0) {
s_max_val = max_val;
}
__syncthreads();
float sum_val = 0.0f;
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
logits[offset + tid] = __expf((float)logits[offset + tid] - s_max_val);
sum_val += (float)logits[offset + tid];
}
sum_val = blockReduceSum<float>(sum_val);
if (threadIdx.x == 0) {
s_sum_val = sum_val;
}
__syncthreads();
for (int tid = threadIdx.x; tid < n_padded; tid += blockDim.x) {
logits[offset + tid] = ((float)logits[offset + tid] / (s_sum_val + 1e-6f));
}
}
template<typename T>
void invokeAddBiasSoftMax(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream)
{
dim3 grid(m);
dim3 block(min(n, 1024));
/* n is the vocab_size, e.g., 30000 or 7000; the vocab size is usually very large. */
addBiasSoftMax<<<grid, block, 0, stream>>>(logits, bias, end_ids, finished, n_padded, n);
}
template void invokeAddBiasSoftMax(float* logits,
const float* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
template void invokeAddBiasSoftMax(half* logits,
const half* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
__global__ void computeToppDecay(float* runtime_top_p,
const float* runtime_initial_top_p,
const int* output_ids,
const float* top_p_decay,
const float* top_p_min,
const int32_t* top_p_reset_ids,
const int local_batch_size)
{
/**
* @brief Compute the top_p decay following https://arxiv.org/pdf/2206.04624.pdf
* In short, the formula is
*   runtime_top_p = max(runtime_top_p * top_p_decay, top_p_min)
* If the token given by top_p_reset_ids is generated, runtime_top_p is reset to its initial value.
*
* \param runtime_top_p [local_batch_size]
* \param runtime_initial_top_p [local_batch_size]
* \param output_ids [local_batch_size]
* \param top_p_decay [local_batch_size]
* \param top_p_min [local_batch_size]
* \param top_p_reset_ids [local_batch_size]
* \param local_batch_size
*
*/
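    // Example (illustrative values): with top_p_decay = 0.9 and top_p_min = 0.5, a row that
    // starts at runtime_top_p = 0.95 decays to 0.855, 0.7695, ... and is clamped at 0.5,
    // until the reset token is generated and it jumps back to runtime_initial_top_p.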
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    // guard against the partial last block when local_batch_size is not a multiple of blockDim.x
    if (idx >= local_batch_size) {
        return;
    }
    if (output_ids[idx] == top_p_reset_ids[idx]) {
runtime_top_p[idx] = runtime_initial_top_p[idx];
}
else {
runtime_top_p[idx] = max(runtime_top_p[idx] * top_p_decay[idx], top_p_min[idx]);
}
}
void invokeComputeToppDecay(float* runtime_top_p,
const float* runtime_initial_top_p,
const int* output_ids,
const float* top_p_decay,
const float* top_p_min,
const int32_t* top_p_reset_ids,
const int local_batch_size,
cudaStream_t stream)
{
dim3 block(min(local_batch_size, 512));
dim3 grid((local_batch_size + block.x - 1) / block.x);
computeToppDecay<<<grid, block, 0, stream>>>(
runtime_top_p, runtime_initial_top_p, output_ids, top_p_decay, top_p_min, top_p_reset_ids, local_batch_size);
}
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <curand_kernel.h>
namespace fastertransformer {
void invokeTopPInitialize(int* topp_id_val_buf,
int* topp_offset_buf,
int* begin_topp_offset_buf_,
const size_t batch_size,
const int n,
cudaStream_t stream);
template<typename T>
void invokeTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float top_p,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template<typename T>
void invokeBatchTopPSampling(void* workspace,
size_t& workspace_size,
size_t& cub_temp_storage_size,
int* output_ids,
int* sequence_length,
bool* finished_buf,
float* cum_log_probs,
float* output_log_probs,
const T* log_probs,
const int* id_vals,
int* offset_buf,
int* begin_offset_buf,
curandState_t* curandstate,
const int batch_size,
const size_t vocab_size_padded,
const int* end_ids,
const float max_top_p,
const float* top_ps,
cudaStream_t stream,
cudaDeviceProp* cuda_device_prop,
const bool* skip_decode);
template<typename T>
void invokeAddBiasSoftMax(T* logits,
const T* bias,
const int* end_ids,
const bool* finished,
const int m,
const int n_padded,
const int n,
cudaStream_t stream);
namespace segmented_topp_impl {
enum DType_t {
kFLOAT,
kHALF,
kINT8
};
template<typename Key_Data_Type_ = float,
typename Value_Data_Type_ = int32_t,
int BLOCK_THREADS_ = 256,
int KEYS_PER_LDG_ = 1>
struct Segmented_topk_kernel_params {
typedef Key_Data_Type_ Key_Data_Type;
typedef Value_Data_Type_ Value_Data_Type;
enum {
BLOCK_THREADS = BLOCK_THREADS_
};
enum {
ITEMS_INCREMENT = 32
};
// enum { KEYS_PER_LDG = 2 * 4 / sizeof(Key_Data_Type_) };
enum {
KEYS_PER_LDG = KEYS_PER_LDG_
};
};
struct TopKPerSegmentContext {
TopKPerSegmentContext(): sm_count(0), sm_shared_size(0), sm_version(0){};
int sm_count;
int sm_shared_size;
int sm_version;
};
struct TopKPerSegmentParams {
// input/output keys and values
void *gmem_src_keys, *gmem_dst_keys, *gmem_dst_vals;
// not used in the custom implementation
void* gmem_src_vals;
// int array of size num_segments
int* gmem_active_count_per_segment;
int* gmem_active_count_total;
int* gmem_begin_offsets;
// gmem_end_offsets will be populated
int* gmem_end_offsets;
void* workspace;
// total number of items for all segments
int num_items;
int num_segments;
// top_k per segment
int num_top_k;
float top_p;
float confidence_threshold;
};
int topPPerSegment(const TopKPerSegmentContext& context,
TopKPerSegmentParams& params,
const DType_t DT_SCORE,
void* temp_storage,
size_t& temp_storage_bytes,
cudaStream_t stream);
} // namespace segmented_topp_impl
void invokeComputeToppDecay(float* runtime_top_p,
const float* runtime_initial_top_p,
const int* output_ids,
const float* top_p_decay,
const float* top_p_min,
const int32_t* top_p_reset_ids,
const int local_batch_size,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/kernels/stop_criteria_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/memory_utils.h"
namespace fastertransformer {
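// stop_words lays out, per batch entry, two rows of length stop_words_len: the first row
// is the flattened token ids of all stop sequences, the second row holds the (exclusive)
// end offset of each sequence, padded with -1. One thread per (stop sequence, batch, beam)
// checks whether the most recently generated tokens match that sequence.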
__global__ void stop_words_criterion(const int* output_ids,
const int* parent_ids,
const int* stop_words,
bool* finished,
size_t id_offset,
size_t stop_words_len,
int batch_size,
int beam_width,
int step)
{
const int id = blockIdx.x * blockDim.x + threadIdx.x;
const int batch_idx = blockIdx.y / beam_width;
const int beam_idx = blockIdx.y % beam_width;
const int* base_stop_words = stop_words + batch_idx * 2 * stop_words_len;
const int* base_offsets = base_stop_words + stop_words_len;
if (id >= stop_words_len || base_offsets[id] < 0) {
return;
}
const int item_end = base_offsets[id];
const int item_start = (id > 0) ? base_offsets[id - 1] : 0;
const int item_size = item_end - item_start;
/* A single-token stop word only needs the last generated token to match */
bool should_stop = false;
/* Enough previously generated tokens to look for a match */
if (step + 1 >= item_size) {
should_stop = true;
int parent_id = beam_idx;
const bool gather_beam = beam_width > 1;
for (int token_idx = item_size - 1; token_idx >= 0; token_idx--) {
const int previous_token = output_ids[(step - (item_size - 1) + token_idx) * batch_size * beam_width
+ id_offset + batch_idx * beam_width + parent_id];
if (previous_token != base_stop_words[item_start + token_idx]) {
should_stop = false;
break;
}
if (gather_beam) {
parent_id = parent_ids[(step - (item_size - 1) + token_idx) * beam_width * batch_size + id_offset
+ batch_idx * beam_width + parent_id];
if (parent_id < 0 || parent_id >= beam_width) {
should_stop = false;
break;
}
}
}
}
if (should_stop) {
finished[batch_idx * beam_width + beam_idx] = true;
}
}
void invokeStopWordsCriterion(const int* output_ids,
const int* parent_ids,
const int* stop_words,
bool* finished,
size_t id_offset,
size_t stop_words_len,
int batch_size,
int beam_width,
int step,
cudaStream_t stream)
{
FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
// Check if we have sampled a word from the stop_words list. If so, stop the sequence.
dim3 block, grid;
block.x = min(((stop_words_len + 32 - 1) / 32) * 32, 256UL);
grid.x = (stop_words_len + block.x - 1) / block.x;
grid.y = batch_size * beam_width;
stop_words_criterion<<<grid, block, 0, stream>>>(
output_ids, parent_ids, stop_words, finished, id_offset, stop_words_len, batch_size, beam_width, step);
sync_check_cuda_error();
}
__global__ void length_criterion(bool* finished,
bool* should_stop,
int* finished_sum,
const uint32_t* sequence_limit_length,
int batch_size,
int beam_width,
int step)
{
int thread_finished_count = 0;
for (int index = threadIdx.x; index < batch_size * beam_width; index += blockDim.x) {
const int batch_idx = index / beam_width;
finished[index] |= step >= sequence_limit_length[batch_idx];
thread_finished_count += finished[index] ? 1 : 0;
}
int block_finished_count = 0;
if (blockDim.x <= 32) {
block_finished_count = warpReduceSum(thread_finished_count);
}
else {
block_finished_count = blockReduceSum(thread_finished_count);
}
__syncthreads();
if (threadIdx.x == 0) {
finished_sum[0] = block_finished_count;
}
}
void invokeLengthCriterion(bool* finished,
bool* should_stop,
int* h_pinned_finished_sum_,
const uint32_t* sequence_limit_length,
int batch_size,
int beam_width,
int step,
cudaStream_t stream)
{
// Check if we have attained the sequence length limit. If so, stop the sequence.
// In addition, check if all sequences are stopped and return the result in should_stop
FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
dim3 block{min(512, uint32_t(batch_size * beam_width))};
dim3 grid{1};
h_pinned_finished_sum_[0] = -1;
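    // h_pinned_finished_sum_ is pinned host memory visible to the device: the kernel
    // overwrites the -1 sentinel with the number of finished sequences, and the host spins
    // on the volatile read below until that write becomes visible, avoiding a full stream
    // synchronization.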
length_criterion<<<grid, block, 0, stream>>>(
finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
sync_check_cuda_error();
*should_stop = h_pinned_finished_sum_[0] == batch_size * beam_width;
}
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
namespace fastertransformer {
void invokeStopWordsCriterion(const int* output_ids,
const int* parent_ids,
const int* stop_words,
bool* finished,
size_t id_offset,
size_t stop_words_len,
int batch_size,
int beam_width,
int step,
cudaStream_t stream);
void invokeLengthCriterion(bool* finished,
bool* should_stop,
int* finished_sum,
const uint32_t* sequence_limit_length,
int batch_size,
int beam_width,
int step,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
__inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4;
}
template<typename T>
__global__ void addQKVBiasIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict bias_q,
const T* __restrict k_in,
const T* __restrict bias_k,
const T* __restrict v_in,
const T* __restrict bias_v,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
T q = ldg(&q_in[src_id]);
q_out[target_id] = add(q, ldg(&bias_q[col_id]));
T k = add(ldg(&k_in[src_id]), ldg(&bias_k[col_id]));
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = add(ldg(&v_in[src_id]), ldg(&bias_v[col_id]));
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
template<typename T>
__global__ void QKVIA3Transpose(T* q_out,
T* k_out,
T* v_out,
const T* __restrict q_in,
const T* __restrict k_in,
const T* __restrict v_in,
const int* ia3_tasks,
const T* __restrict ia3_key_weights,
const T* __restrict ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head)
{
const int n = head_num * size_per_head;
const int batch_id = blockIdx.x;
const int word_id = blockIdx.y;
const int row_id = batch_id * seq_len + word_id;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) {
const int head_id = col_id / size_per_head;
const int size_id = col_id % size_per_head;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head
+ word_id * size_per_head + size_id;
const int src_id = row_id * n + col_id;
q_out[target_id] = ldg(&q_in[src_id]);
T k = ldg(&k_in[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + col_id];
}
k_out[target_id] = k;
T v = ldg(&v_in[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + col_id];
}
v_out[target_id] = v;
}
}
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
const int k = head_num * size_per_head;
dim3 grid(batch_size, seq_len);
bool is_add_bias = bias_Q != nullptr;
if (sizeof(T) == 4 || k % 2 != 0) {
dim3 block(min(k, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
bias_Q,
K,
bias_K,
V,
bias_V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
else {
QKVIA3Transpose<T><<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
Q,
K,
V,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head);
}
sync_check_cuda_error();
}
else {
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
dim3 block(min(k / 2, 512));
if (is_add_bias) {
addQKVBiasIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)bias_Q,
(const T2*)K,
(const T2*)bias_K,
(const T2*)V,
(const T2*)bias_V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
else {
QKVIA3Transpose<T2><<<grid, block, 0, stream>>>((T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
(const T2*)Q,
(const T2*)K,
(const T2*)V,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2);
}
sync_check_cuda_error();
}
}
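// Summary: both kernels above reshape Q/K/V from [batch, seq_len, head_num, size_per_head]
// (row-major over the fused hidden dimension) into [batch, head_num, seq_len, size_per_head],
// optionally adding the per-channel bias and, when ia3_tasks is given, rescaling K/V by the
// per-task IA3 weights. The half2/bf162 path processes two channels per thread, which is why
// size_per_head is halved in those launches.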
#define INSTANTIATEADDQKVBIASIA3TRANSPOSE(T) \
template void invokeAddQKVBiasIA3Transpose(T* q_buf, \
T* k_buf, \
T* v_buf, \
T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3TRANSPOSE(float);
INSTANTIATEADDQKVBIASIA3TRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3TRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASIA3TRANSPOSE
template<typename T, typename T_IN, int ITEMS_PER_THREAD>
__global__ void softmax_kernel(T* attn_score,
const T_IN* qk,
const T* attn_mask,
const T* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const float qk_scale)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
const int bi = blockIdx.y; // Batch index.
const int hi = blockIdx.z; // Head index.
__shared__ float s_mean, s_max;
const float linear_bias_slope = linear_bias_slopes != nullptr ? (float)linear_bias_slopes[hi] : 0.0f;
// Loop over the Q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
float data[ITEMS_PER_THREAD];
int qk_offset;
float local_max = -1e20f;
// Loop over the K dimension.
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
int ki = blockDim.x * i + threadIdx.x; // Index of K dimension.
qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + ki;
float qk_val = static_cast<float>(qk[qk_offset]);
float qk_bias = 0.0f;
if (linear_bias_slopes != nullptr) {
// We don't handle the upper triangle (ki > qi) separately; its values are
// negligible because of the negative-infinity mask. This also matches HF's
// implementation.
qk_bias += static_cast<float>(linear_bias_slope * (ki - qi));
}
int mask_offset = (bi * q_length + qi) * k_length + ki;
float mask_val = static_cast<float>(ldg(&attn_mask[mask_offset]));
qk_bias += (1.0f - mask_val) * -10000.0f;
data[i] = qk_scale * qk_val + qk_bias;
local_max = fmax(local_max, data[i]);
}
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
data[i] = __expf(data[i] - s_max);
local_sum += data[i];
}
float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < k_length; i++) {
qk_offset = ((bi * head_num + hi) * q_length + qi) * k_length + blockDim.x * i + threadIdx.x;
attn_score[qk_offset] = (T)(data[i] * s_mean);
}
}
}
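// In scalar form, the kernel above computes, for each (batch b, head h, query qi):
//   logit(ki) = qk_scale * qk[b,h,qi,ki] + slope[h] * (ki - qi) - 10000 * (1 - mask[b,qi,ki])
//   attn_score[b,h,qi,ki] = exp(logit(ki) - max_k logit(k)) / (sum_k exp(logit(k) - max_k logit(k)) + 1e-6)
// i.e. a numerically stable masked softmax with an optional ALiBi-style linear position bias
// (slope[h] is treated as 0 when linear_bias_slopes == nullptr).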
template<typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2(T* attn_score,
const T* qk_buf,
const T* attn_mask,
const T* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const T qk_scale)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
using T2 = typename TypeConverter<T>::Type;
T2* attn_score_h2 = reinterpret_cast<T2*>(attn_score);
const T2* qk_buf_h2 = reinterpret_cast<const T2*>(qk_buf);
const T2* attn_mask_h2 = reinterpret_cast<const T2*>(attn_mask);
const int bi = blockIdx.y; // Batch index
const int hi = blockIdx.z; // Head index.
__shared__ float s_mean, s_max;
// Constant values that will be used repeatedly in the q/k loop.
const T2 ONE = cuda_cast<T2>(1.0f);
const T2 ZERO = cuda_cast<T2>(0.0f);
const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);
// The normalization factor of QK.
const T2 qk_scale_h2 = cuda_cast<T2>(qk_scale);
// The slope of a linear position bias of the current attention head.
const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;
// Loop over q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
T2 data[ITEMS_PER_THREAD];
int qk_offset;
float local_max = -1e20f;
// Loop over k dimension.
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
// The half of the index of k dimension. We will use the elements at {2 * ki, 2 * ki + 1}.
int ki = blockDim.x * i + threadIdx.x;
qk_offset = ((bi * head_num + hi) * q_length + qi) * (k_length / 2) + ki;
int mask_offset = (bi * q_length + qi) * (k_length / 2) + ki;
// The value of QK^T matrix at (qi, ki).
T2 qk = qk_buf_h2[qk_offset];
// The bias value to the position (qi, ki) including both mask and positional bias.
T2 qk_bias = ZERO;
if (linear_bias_slopes != nullptr) {
// The position bias depends on the distance between qi/ki and is zero if qi >= 2*ki
// or qi >= 2*ki+1. For T2 vectorization, we handle two elements along the k-dim at a
// time. To do this, we check qi / 2 > ki at once instead of qi >= 2*ki or 2*ki+1. It
// works because a diagonal element for an odd qi will be zero due to
// slope * (qi - 2*ki+1) = 0. Thus, we don't handle the upper diagonal separately;
// its values are negligible due to the negative infinity mask.
T2 dist(2.0f * ki - qi, 2.0f * ki + 1 - qi);
qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(linear_bias_slope, dist));
}
T2 mask_val = ldg(&attn_mask_h2[mask_offset]);
qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(hsub2<T2>(ONE, mask_val), NEG_INFTY));
data[i] = hadd2<T2>(hmul2<T2>(qk, qk_scale_h2), qk_bias);
local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y));
}
float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0.0f;
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
data[i] = hexp2<T2>(hsub2<T2>(data[i], cuda_cast<T2>(s_max)));
local_sum += (float)(data[i].x + data[i].y);
}
float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
qk_offset = ((bi * head_num + hi) * q_length + qi) * (k_length / 2) + blockDim.x * i + threadIdx.x;
attn_score_h2[qk_offset] = hmul2<T2>(data[i], cuda_cast<T2>(s_mean));
}
}
}
template<typename T, int K_ITEMS_PER_THREAD, int Q_ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2_v2(T* attn_score,
const T* qk_buf,
const T* attn_mask,
const T* linear_bias_slopes,
const int batch_size,
const int head_num,
const int q_length,
const int k_length,
const T scalar)
{
// attn_score, [batch_size, num_heads, q_length, k_length]
// qk, [batch_size, num_heads, q_length, k_length]
// attn_mask, [batch_size, q_length, k_length]
// linear_bias_slopes, [num_heads]
using T2 = typename TypeConverter<T>::Type;
// QK^T matrix of shape (batch_size, head_num, q_length, k_length / 2)
T2* attn_score_h2 = reinterpret_cast<T2*>(attn_score);
const T2* qk_buf_h2 = reinterpret_cast<const T2*>(qk_buf);
const T2* attn_mask_h2 = reinterpret_cast<const T2*>(attn_mask);
const int bi = blockIdx.y; // Batch index
const int hi = blockIdx.z; // Head index.
// Constant values that will be used repeatedly in the q/k loop.
const T2 ONE = cuda_cast<T2>(1.0f);
const T2 ZERO = cuda_cast<T2>(0.0f);
const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);
// The normalization factor of QK.
const T2 qk_scale = cuda_cast<T2>(scalar);
// The slope of a linear position bias of the current attention head.
const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;
__shared__ float s_sum[Q_ITEMS_PER_THREAD], s_max[Q_ITEMS_PER_THREAD];
// Loop over q dimension.
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x * Q_ITEMS_PER_THREAD) {
T2 data[Q_ITEMS_PER_THREAD][K_ITEMS_PER_THREAD];
int qk_offset[Q_ITEMS_PER_THREAD];
float local_max[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
local_max[j] = -1e20f;
}
// Loop over k dimension.
const int Q_ITEMS = min((q_length - qi + gridDim.x - 1) / gridDim.x, Q_ITEMS_PER_THREAD);
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
// The half of the index of k dimension. We will use the elements at {2 * ki, 2 * ki + 1}.
int ki = blockDim.x * i + threadIdx.x;
int mask_offset[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk_offset[j] = ((bi * head_num + hi) * q_length + qi + j * gridDim.x) * (k_length / 2) + ki;
mask_offset[j] = (bi * q_length + qi + j * gridDim.x) * (k_length / 2) + ki;
}
T2 mask_val[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
mask_val[j] = ldg(&attn_mask_h2[mask_offset[j]]);
}
T2 qk[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk[j] = qk_buf_h2[qk_offset[j]];
}
T2 pos_bias[Q_ITEMS_PER_THREAD];
if (linear_bias_slopes != nullptr) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
// The position bias depends on the distance between qi/ki and is zero if qi >= 2*ki
// or qi >= 2*ki+1. For T2 vectorization, we handle two elements along the k-dim at a
// time. To do this, we check qi / 2 > ki at once instead of qi >= 2*ki or 2*ki+1. It
// works because a diagonal element for an odd qi will be zero due to
// slope * (qi - 2*ki+1) = 0. Thus, we don't handle the upper diagonal separately;
// its values are negligible due to the negative infinity mask.
int qidx = qi + j * gridDim.x;
T2 dist(2.0f * ki - qidx, 2.0f * ki + 1 - qidx);
pos_bias[j] = hmul2<T2>(linear_bias_slope, dist);
}
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
mask_val[j] = hmul2<T2>(hsub2<T2>(ONE, mask_val[j]), NEG_INFTY);
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
T2 val = hadd2<T2>(hmul2<T2>(qk_scale, qk[j]), mask_val[j]);
if (linear_bias_slopes != nullptr) {
val = hadd2<T2>(val, pos_bias[j]);
}
data[j][i] = val;
local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y));
}
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
}
else {
blockReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
s_max[j] = local_max[j];
}
}
__syncthreads();
float local_sum[Q_ITEMS_PER_THREAD];
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
local_sum[j] = {0.f};
}
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; ++j) {
data[j][i] = hexp2<T2>(hsub2<T2>(data[j][i], cuda_cast<T2>(s_max[j])));
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
local_sum[j] += (float)(data[j][i].x + data[j][i].y);
}
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
}
else {
blockReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
qk_offset[j] = ((bi * head_num + hi) * q_length + qi + j * gridDim.x) * (k_length / 2) + blockDim.x * i
+ threadIdx.x;
}
#pragma unroll
for (int j = 0; j < Q_ITEMS; j++) {
attn_score_h2[qk_offset[j]] = hmul2<T2>(data[j][i], cuda_cast<T2>(s_sum[j]));
}
}
}
}
#define LAUNCH_MAKSED_SOFTMAX_(T_, ITEMS_PER_THREAD) \
block.x /= ITEMS_PER_THREAD; \
block.x = (block.x + 31) / 32 * 32; \
assert(block.x <= 1024); \
if (is_half2) { \
if (grid.x % 4 == 0) { \
grid.x /= 4; \
softmax_kernel_h2_v2<T_, ITEMS_PER_THREAD, 4> \
<<<grid, block, 0, stream>>>((T_*)param.attention_score, \
(const T_*)param.qk, \
(const T_*)param.attention_mask, \
(const T_*)param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
(const T_)param.qk_scale); \
} \
else { \
softmax_kernel_h2<T_, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>((T_*)param.attention_score, \
(const T_*)param.qk, \
(const T_*)param.attention_mask, \
(const T_*)param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
(const T_)param.qk_scale); \
} \
} \
else { \
softmax_kernel<T, T_IN, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>(param.attention_score, \
param.qk, \
param.attention_mask, \
param.linear_bias_slopes, \
param.batch_size, \
param.num_heads, \
param.q_length, \
param.k_length, \
param.qk_scale); \
}
#define LAUNCH_MAKSED_SOFTMAX(ITEMS_PER_THREAD) LAUNCH_MAKSED_SOFTMAX_(half, ITEMS_PER_THREAD)
template<typename T, typename T_IN>
void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) {
LAUNCH_MAKSED_SOFTMAX(4)
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX(2)
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX(1)
}
else {
FT_CHECK(param.k_length <= 4096);
}
}
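// Usage sketch (illustrative; the pointer names below are hypothetical, the field names are
// the ones read by the launcher above):
//   MaskedSoftmaxParam<half, half> param;
//   param.attention_score    = d_attn_score;  // [batch, head, q_len, k_len], output
//   param.qk                 = d_qk;          // [batch, head, q_len, k_len], QK^T input
//   param.attention_mask     = d_mask;        // [batch, q_len, k_len]
//   param.linear_bias_slopes = nullptr;       // optional [head] ALiBi slopes
//   param.batch_size = batch_size;
//   param.num_heads  = head_num;
//   param.q_length   = q_len;
//   param.k_length   = k_len;
//   param.qk_scale   = (half)(1.0f / sqrtf((float)size_per_head));
//   invokeMaskedSoftmax(param, stream);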
template void invokeMaskedSoftmax(MaskedSoftmaxParam<float, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, half>& param, cudaStream_t stream);
#ifdef ENABLE_BF16
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, float>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
using T = __nv_bfloat16;
using T_IN = float;
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
}
else {
FT_CHECK(param.k_length <= 4096);
}
}
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param, cudaStream_t stream)
{
// attention_score, (batch_size, head_num, q_length, k_length), softmax output.
// qk, (batch_size, head_num, q_length, k_length), QK^T.
// attention_mask, (batch_size, q_length, k_length), attention mask.
// linear_bias_slopes, (head_num,) the slopes of the linear position bias.
using T = __nv_bfloat16;
using T_IN = __nv_bfloat16;
dim3 grid(param.q_length, param.batch_size, param.num_heads);
if (param.batch_size * param.num_heads > 360) {
grid.x = ceil(float(param.q_length) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
}
else if (block.x > 1024) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
}
else if (block.x > 0) {
LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
}
else {
FT_CHECK(param.k_length <= 4096);
}
}
#endif
#undef LAUNCH_MAKSED_SOFTMAX
#undef LAUNCH_MAKSED_SOFTMAX_
template<typename T>
__global__ void transpose(const T* src,
T* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
int int8_mode)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int batch_id = tid / (head_num * seq_len * size_per_head);
int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head);
int seq_id = (tid % (seq_len * size_per_head)) / size_per_head;
int id = tid % size_per_head;
int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head);
if (int8_mode == 2) {
using Int8_Packed_T = typename packed_as<int8_t, num_elems<T>::value>::type;
using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;
const Float_Packed_T scale_val = cuda_cast<Float_Packed_T>(*scale);
reinterpret_cast<Int8_Packed_T*>(dst)[target_id] =
cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(src[tid]) * scale_val);
}
else {
dst[target_id] = src[tid];
}
}
template<>
__global__ void transpose(const float* src,
float* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
int int8_mode)
{
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
const int target_id = batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head
+ head_id * size_per_head + threadIdx.x;
const int src_id = blockIdx.x * size_per_head + threadIdx.x;
if (int8_mode == 2) {
const float scale_val = *scale;
reinterpret_cast<int8_t*>(dst)[target_id] = cuda_cast<int8_t>(src[src_id] * scale_val);
}
else {
dst[target_id] = src[src_id];
}
}
template<typename T>
void invokeTransposeQKV(T* dst,
T* src,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
dim3 grid, block;
if (sizeof(T) == 2) {
int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
while (seq_per_block < 4 && grid.x % 2 == 0) {
grid.x /= 2;
seq_per_block *= 2;
}
FT_CHECK(grid.x * seq_per_block == (size_t)batch_size * head_num * seq_len);
if (seq_per_block * size_per_head % 2 == 0) {
block.x = seq_per_block * size_per_head / 2;
if (std::is_same<T, half>::value) {
transpose<half2><<<grid, block, 0, stream>>>(
(half2*)src, (half2*)dst, batch_size, seq_len, head_num, size_per_head / 2, scale, int8_mode);
}
#ifdef ENABLE_BF16
else {
transpose<__nv_bfloat162><<<grid, block, 0, stream>>>((__nv_bfloat162*)src,
(__nv_bfloat162*)dst,
batch_size,
seq_len,
head_num,
size_per_head / 2,
scale,
int8_mode);
}
#endif
}
else {
block.x = seq_per_block * size_per_head;
transpose<T>
<<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
}
}
else {
const int seq_per_block = 1;
grid.x = batch_size * head_num * seq_len / seq_per_block;
block.x = seq_per_block * size_per_head;
transpose<T>
<<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
}
}
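// Summary: transposes the attention output from [batch, head_num, seq_len, size_per_head]
// to [batch, seq_len, head_num, size_per_head], i.e. the row-major [batch, seq_len, hidden]
// layout used downstream. With int8_mode == 2 the result is additionally quantized to int8
// using *scale. The 16-bit path packs two elements per thread (half2 / bf162).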
#define INSTANTIATETRANSPOSEQKV(T) \
template void invokeTransposeQKV(T* dst,                                                                           \
                                 T* src,                                                                           \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATETRANSPOSEQKV(float);
INSTANTIATETRANSPOSEQKV(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEQKV(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEQKV
template<typename T>
__global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q,
const T* bias_Q,
const T* K,
const T* bias_K,
const T* V,
const T* bias_V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = add(ldg(&Q[src_id]), ldg(&bias_Q[idx]));
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = add(k, ldg(&bias_K[idx]));
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = add(v, ldg(&bias_V[idx]));
}
}
template<typename T>
__global__ void rebuild_padding_ia3(const T* Q,
const T* K,
const T* V,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset)
{
const int bid = blockIdx.x;
const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len;
const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len;
const int n = head_num * size_per_head;
const bool use_ia3 = ia3_tasks != nullptr;
const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0;
const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr);
const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr);
for (int idx = threadIdx.x; idx < n; idx += blockDim.x) {
const int tgt_head_id = idx / size_per_head;
const int tgt_hidden_id = idx % size_per_head;
const int src_id = bid * n + idx;
const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head
+ tgt_seq_id * size_per_head + tgt_hidden_id;
q_buf_[tgt_id] = ldg(&Q[src_id]);
T k = ldg(&K[src_id]);
if (use_ia3_key) {
k = k * ia3_key_weights[ia3_task * n + idx];
}
k_buf_[tgt_id] = k;
T v = ldg(&V[src_id]);
if (use_ia3_value) {
v = v * ia3_value_weights[ia3_task * n + idx];
}
v_buf_[tgt_id] = v;
}
}
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream)
{
#ifdef ENABLE_BF16
bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);
#else
bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);
#endif
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
int block_size = head_num * size_per_head;
if (is_half2) {
while (block_size > 512) {
if (block_size % 2 == 0) {
block_size /= 2;
}
else {
is_half2 = false;
block_size = std::min(block_size, 512);
break;
}
}
}
else {
block_size = std::min(block_size, 512);
}
if (bias_Q == nullptr && bias_K == nullptr && bias_V == nullptr) {
if (is_half2) {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(T2*)K,
(T2*)V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
K,
V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else if (bias_Q != nullptr && bias_K != nullptr && bias_V != nullptr) {
if (is_half2) {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>((T2*)Q,
(const T2*)bias_Q,
(T2*)K,
(const T2*)bias_K,
(T2*)V,
(const T2*)bias_V,
(T2*)q_buf,
(T2*)k_buf,
(T2*)v_buf,
ia3_tasks,
(const T2*)ia3_key_weights,
(const T2*)ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head / 2,
mask_offset);
}
else {
add_QKV_bias_rebuild_padding_ia3<<<valid_word_num, block_size, 0, stream>>>(Q,
bias_Q,
K,
bias_K,
V,
bias_V,
q_buf,
k_buf,
v_buf,
ia3_tasks,
ia3_key_weights,
ia3_value_weights,
batch_size,
seq_len,
head_num,
size_per_head,
mask_offset);
}
}
else {
FT_CHECK(false);
}
}
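// Note: valid_word_num is the number of non-padded tokens; mask_offset[i] maps the i-th packed
// (padding-free) token back to its (batch, seq) slot, so the kernels scatter each packed row of
// Q/K/V into the dense [batch, head_num, seq_len, size_per_head] buffers. The bias pointers must
// be either all null (rebuild_padding_ia3) or all non-null (add_QKV_bias_rebuild_padding_ia3);
// mixing them trips the FT_CHECK above.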
#define INSTANTIATEADDQKVBIASIA3REBUILDPADDING(T) \
template void invokeAddQKVBiasIA3RebuildPadding(T* Q, \
const T* bias_Q, \
T* K, \
const T* bias_K, \
T* V, \
const T* bias_V, \
T* q_buf, \
T* k_buf, \
T* v_buf, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int valid_word_num, \
const int* mask_offset, \
const int* ia3_tasks, \
const T* ia3_key_weights, \
const T* ia3_value_weights, \
cudaStream_t stream)
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(float);
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(half);
#ifdef ENABLE_BF16
INSTANTIATEADDQKVBIASIA3REBUILDPADDING(__nv_bfloat16);
#endif
#undef INSTANTIATEADDQKVBIASIA3REBUILDPADDING
template<typename T>
__global__ void transpose_remove_padding(const T* src,
T* dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset,
const float* scale,
const int int8_mode)
{
// TODO: optimize this kernel?
// do remove_sequence_length_padding
const int bid = blockIdx.x; // batch * seq_len or valid_word_num
const int src_batch_id = (bid + mask_offset[bid]) / seq_len;
const int src_seq_id = (bid + mask_offset[bid]) % seq_len;
const int dst_seq_id = bid;
const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head;
const int dst_offset_base = dst_seq_id * head_num * size_per_head;
using Int8_Packed_T = typename packed_as<int8_t, num_elems<T>::value>::type;
using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;
const Float_Packed_T scale_val =
int8_mode == 2 ? cuda_cast<Float_Packed_T>(*scale) : cuda_cast<Float_Packed_T>(0.0f);
for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) {
const int head_id = idx / size_per_head;
const int hidden_id = idx % size_per_head;
const T src_elem = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]);
if (int8_mode == 2) {
reinterpret_cast<Int8_Packed_T*>(dst)[dst_offset_base + idx] =
cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(src_elem) * scale_val);
}
else {
dst[dst_offset_base + idx] = src_elem;
}
}
}
// clang-format off
template<typename T>
void invokeTransposeAttentionOutRemovePadding(T* src,
T* dst,
const int valid_word_num,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
#ifdef ENABLE_BF16
bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (size_per_head % 2 == 0);
#else
bool is_half2 = (std::is_same<T, half>::value) && (size_per_head % 2 == 0);
#endif
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
int block_size = head_num * size_per_head;
if (is_half2) {
while (block_size > 512) {
if (block_size % 2 == 0) {
block_size /= 2;
}
else {
is_half2 = false;
block_size = std::min(block_size, 1024);
break;
}
}
}
else {
block_size = std::min(block_size, 1024);
}
if (is_half2) {
transpose_remove_padding<T2><<<valid_word_num, block_size, 0, stream>>>(
(T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset, scale, int8_mode);
}
else {
transpose_remove_padding<<<valid_word_num, block_size, 0, stream>>>(
src, dst, batch_size, seq_len, head_num, size_per_head, mask_offset, scale, int8_mode);
}
}
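// Summary: inverse of the padding rebuild above; gathers the dense attention output
// [batch, head_num, seq_len, size_per_head] back into a packed [valid_word_num,
// head_num * size_per_head] buffer, optionally quantizing to int8 when int8_mode == 2.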
// clang-format on
#define INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(T) \
template void invokeTransposeAttentionOutRemovePadding(T* src, \
T* dst, \
const int valid_word_num, \
const int batch_size, \
const int seq_len, \
const int head_num, \
const int size_per_head, \
const int* mask_offset, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(float);
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEATTENTIONOUTREMOVEPADDING
template<typename T>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const T* __restrict qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const float* scale,
const int int8_mode)
{
// QKV: [token_num, 3, n]
// qkv_bias: [3, n]
// q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head]
T* qkv_ptr[3] = {q_buf, k_buf, v_buf};
const int n = head_num * size_per_head;
for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * 3 * n;
index += gridDim.x * blockDim.x) {
const int bias_id = index % (3 * n);
const int token_idx = index / (3 * n);
const int token_padded_idx = token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]);
const int target_batch_id = token_padded_idx / seq_len;
const int seq_id = token_padded_idx % seq_len;
const int qkv_id = (index % (3 * n)) / n;
const int head_id = (index % n) / size_per_head;
const int size_id = index % size_per_head;
T val;
if (int8_mode == 2) {
val = cuda_cast<T>(cuda_cast<float>(reinterpret_cast<const int8_t*>(QKV)[index]) * scale[qkv_id]);
}
else {
val = ldg(&QKV[index]);
}
val = val + ldg(&qkv_bias[bias_id]);
if (int8_mode == 2) {
// TODO(mseznec): add support for int8 BMM with FusedAtt
}
else {
QKV[index] = val;
}
qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head
+ seq_id * size_per_head + size_id] = val;
}
}
template<typename T>
struct Vec_t {
static constexpr int size = 0;
};
template<>
struct Vec_t<float> {
using Type = float2;
static constexpr int size = 2;
};
template<>
struct Vec_t<half> {
using Type = uint32_t;
static constexpr int size = 2;
};
#ifdef ENABLE_BF16
template<>
struct Vec_t<__nv_bfloat16> {
using Type = __nv_bfloat162;
static constexpr int size = 2;
};
#endif
/// TODO: support batch step offset
template<typename T, bool PREFIX_PROMPT>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* __restrict qkv_bias,
const int* padding_offset,
const int* history_length,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const bool neox_rotary_style)
{
// This kernel adds bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head],
// splits QKV into three buffers q, k, v, and transposes them to [batch_size, head_num, seq_len, size_per_head].
// For q and k, it also applies the rotary embedding.
// When a prefix prompt is passed, this kernel additionally concatenates the prefix prompt and key/value
// along the seq_len dimension as [prompt, key/value].
// So the final shape of q stays the same ([batch_size, head_num, seq_len, size_per_head]), but
// the shapes of key and value become [batch_size, head_num, max_prefix_prompt_length + seq_len, size_per_head].
// NOTE: QKV src shape (batch_size, seq_len, 3, head_num, size_per_head)
// QKV dst shape (3, batch_size, head_num, seq_len, size_per_head)
extern __shared__ __align__(sizeof(float2)) char smem_[]; // align on largest vector type
constexpr int vec_size = Vec_t<T>::size;
using Vec_t = typename Vec_t<T>::Type;
const int token_idx = blockIdx.x - batch_size * param.max_prefix_prompt_length;
const int token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx];
const int tgt_token_idx = token_idx + token_padding_offset;
const int batch_idx = tgt_token_idx / seq_len;
const int seq_idx = tgt_token_idx % seq_len;
const int head_idx = blockIdx.y;
const int tidx = threadIdx.x;
const int total_seq_len = param.max_prefix_prompt_length + seq_len;
const bool is_masked = tidx * vec_size >= size_per_head;
// NOTE: blockIdx.x < batch_size * param.max_prefix_prompt_length really handles prefix prompts
if (PREFIX_PROMPT && token_idx < 0) {
const int prompt_batch_idx = blockIdx.x / param.max_prefix_prompt_length;
const int prompt_seq_idx = blockIdx.x % param.max_prefix_prompt_length;
const int prompt_length = param.d_prefix_prompt_lengths[prompt_batch_idx];
if (prompt_seq_idx < prompt_length) {
const int dest_kv_idx = prompt_batch_idx * size_per_head * total_seq_len * head_num
+ head_idx * size_per_head * total_seq_len + prompt_seq_idx * size_per_head
+ tidx * vec_size;
const int prefix_kv_idx =
size_per_head * prompt_length * head_idx + size_per_head * prompt_seq_idx + tidx * vec_size;
const T* prefix_prompt_k = param.d_prefix_prompt_batch[prompt_batch_idx]
+ param.prefix_prompt_layer_offset_per_seq * prompt_length;
const T* prefix_prompt_v = prefix_prompt_k + prompt_length * head_num * size_per_head;
if (!is_masked) {
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&prefix_prompt_k[prefix_kv_idx]);
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) =
*reinterpret_cast<const Vec_t*>(&prefix_prompt_v[prefix_kv_idx]);
}
}
return;
}
const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int n = head_num * size_per_head;
// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len)
// and Q [0..seq_len)
// Note: if !PREFIX_PROMPT, max_pp_len = 0, so it's no-op
const int dst_kv_seq_idx = seq_idx + prefix_prompt_length;
// NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden]
const int src_q_idx = token_idx * 3 * n + hidden_idx;
const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
Vec_t q, k, v;
Vec_t q_bias, k_bias, v_bias;
if (!is_masked) {
q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
if (qkv_bias) {
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + 2 * n]);
}
}
if (qkv_bias) {
q = mmha::add(q, q_bias);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
}
const int t_offset = history_length ? history_length[batch_idx] : 0;
if (!neox_rotary_style) {
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset);
}
else {
const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
T* q_smem = reinterpret_cast<T*>(smem_);
T* k_smem = q_smem + rotary_embedding_dim;
const int half_rotary_dim = rotary_embedding_dim / 2;
const int half_idx = (tidx * vec_size) / half_rotary_dim;
const int intra_half_idx = (tidx * vec_size) % half_rotary_dim;
const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts?
if (do_rotary) {
*reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
*reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
}
__syncthreads();
const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
constexpr int tidx_factor = vec_size / 2;
if (do_rotary) {
mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, rotary_embedding_dim, dst_kv_seq_idx + t_offset);
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
}
__syncthreads();
if (do_rotary) {
q = *reinterpret_cast<Vec_t*>(q_smem + half_idx * smem_pitch + intra_half_idx);
k = *reinterpret_cast<Vec_t*>(k_smem + half_idx * smem_pitch + intra_half_idx);
}
}
if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present
*reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
*reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
}
const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
+ seq_idx * size_per_head + tidx * vec_size;
const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num
+ head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head
+ tidx * vec_size;
if (!is_masked) {
*reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
*reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
*reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
}
}
#define FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, PREFIX_PROMPT) \
add_fusedQKV_bias_transpose_kernel<T, PREFIX_PROMPT><<<grid, block, smem_size, stream>>>(q_buf, \
k_buf, \
v_buf, \
param, \
QKV, \
qkv_bias, \
padding_offset, \
history_length, \
batch_size, \
seq_len, \
head_num, \
size_per_head, \
rotary_embedding_dim, \
neox_rotary_style);
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int* history_length,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
// [bs, seq_len, 3, head, Dh]
if (rotary_embedding_dim == 0 && param.max_prefix_prompt_length == 0) {
const int m = token_num;
const int n = head_num * size_per_head;
dim3 block(384);
dim3 grid((int)(ceil(1.0 * m * n / 384)));
add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(q_buf,
k_buf,
v_buf,
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
scale,
int8_mode);
}
else {
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec)
// To implement rotary embeddings, each thread processes two QKV elems:
dim3 block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num);
size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0;
// NOTE: add offset for rotary embedding
// add_fusedQKV_bias_transpose_kernel<<<grid, block, 0, stream>>>(
// q_buf, k_buf, v_buf, param, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head,
// rotary_embedding_dim);
if (param.max_prefix_prompt_length == 0) {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}
else {
FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true);
}
}
}
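// Usage note (illustrative): when rotary_embedding_dim == 0 and no prefix prompt is used, the
// simple grid-stride kernel is launched and QKV is only bias-added and split; otherwise the
// PREFIX_PROMPT kernel additionally applies the rotary embedding to q/k (NeoX-style rotation,
// using shared memory of size 2 * rotary_embedding_dim * sizeof(T), when neox_rotary_style != 0)
// and prepends any prefix-prompt keys/values to k_buf/v_buf.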
#define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T) \
template void invokeAddFusedQKVBiasTranspose(T* q_buf, \
T* k_buf, \
T* v_buf, \
PrefixPromptBatchWeightsParam<T> param, \
T* QKV, \
const T* qkv_bias, \
const int* padding_offset, \
const int* history_length, \
const int batch_size, \
const int seq_len, \
const int token_num, \
const int head_num, \
const int size_per_head, \
const int rotary_embedding_dim, \
const int neox_rotary_style, \
const float* scale, \
const int int8_mode, \
cudaStream_t stream)
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float);
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDFUSEDQKVBIASTRANSPOSE
template<typename T>
__global__ void transpose_4d(T* dst,
T* src,
const int dim0,
const int dim1,
const int dim2,
const int dim3,
const int dim0_leading_dim,
const int ite)
{
// transpose from [dim0, dim1, dim2, dim3] to [dim2, X, dim1, dim3]
// where the dimension of X is dim0_leading_dim, and offset is ite * dim0
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * dim3; i += blockDim.x * gridDim.x) {
int index = i;
const int d3 = index % dim3;
index = (index - d3) / dim3;
const int d2 = index % dim2;
index = (index - d2) / dim2;
const int d1 = index % dim1;
index = (index - d1) / dim1;
const int d0 = index % dim0;
index = (index - d0) / dim0;
dst[d2 * dim0_leading_dim * dim1 * dim3 + (d0 + dim0 * ite) * dim1 * dim3 + d1 * dim3 + d3] = src[i];
}
}
template<>
__global__ void transpose_4d(half* dst,
half* src,
const int dim0,
const int dim1,
const int dim2,
const int dim3,
const int dim0_leading_dim,
const int ite)
{
half2* dst_ptr = (half2*)dst;
half2* src_ptr = (half2*)src;
const int half_dim3 = dim3 / 2;
// transpose from [dim0, dim1, dim2, half_dim3] to [dim2, X, dim1, half_dim3]
// where the dimension of X is dim0_leading_dim, and the offset is ite * dim0
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * half_dim3;
i += blockDim.x * gridDim.x) {
int index = i;
const int d3 = index % half_dim3;
index = (index - d3) / half_dim3;
const int d2 = index % dim2;
index = (index - d2) / dim2;
const int d1 = index % dim1;
index = (index - d1) / dim1;
const int d0 = index % dim0;
index = (index - d0) / dim0;
dst_ptr[d2 * dim0_leading_dim * dim1 * half_dim3 + (d0 + dim0 * ite) * dim1 * half_dim3 + d1 * half_dim3 + d3] =
src_ptr[i];
}
}
template<typename T>
void invokeTranspose4d(T* dst,
T* src,
const int local_batch_size,
const int seq_len,
const int size_per_head,
const int local_hidden_units,
const int local_head_num,
const int batch_size,
const int ite,
cudaStream_t stream)
{
transpose_4d<<<local_batch_size * seq_len * local_hidden_units / 512, 512 / (4 / (sizeof(T))), 0, stream>>>(
dst, src, local_batch_size, local_head_num, seq_len, size_per_head, batch_size, ite);
}
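// Note: in the launch above dim0 = local_batch_size, dim1 = local_head_num, dim2 = seq_len and
// dim3 = size_per_head, so the result is laid out as [seq_len, batch_size, local_head_num,
// size_per_head] with the ite-th local batch written at offset ite * local_batch_size along the
// batch dimension. The half specialization moves half2 elements, so size_per_head must be even.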
#define INSTANTIATETRANSPOSE4D(T) \
template void invokeTranspose4d(T* dst, \
T* src, \
const int local_batch_size, \
const int seq_len, \
const int size_per_head, \
const int local_hidden_units, \
const int local_head_num, \
const int batch_size, \
const int ite, \
cudaStream_t stream)
INSTANTIATETRANSPOSE4D(float);
INSTANTIATETRANSPOSE4D(half);
#undef INSTANTIATETRANSPOSE4D
template<typename T>
__global__ void transpose_4d_batch_major_k_cache(
T* k_dst, const T* k_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len)
{
const int batch_id = blockIdx.y;
const int head_id = blockIdx.z;
constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8;
auto key_src = reinterpret_cast<const uint4*>(k_src + batch_id * head_num * size_per_head * seq_len
+ head_id * size_per_head * seq_len);
auto key_dst = reinterpret_cast<uint4*>(k_dst + batch_id * head_num * size_per_head * max_seq_len
+ head_id * size_per_head * max_seq_len);
const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
int size_per_head_div_x = size_per_head / X_ELEMS;
if (out_idx >= size_per_head_div_x * max_seq_len) {
return;
}
int idx = out_idx;
const int k_seq_len_id = idx % max_seq_len;
idx = (idx - k_seq_len_id) / max_seq_len;
const int k_head_size_id = idx % size_per_head_div_x;
if (k_seq_len_id < seq_len) {
key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id];
}
}
template<typename T>
__global__ void transpose_4d_batch_major_v_cache(
T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len)
{
const int batch_id = blockIdx.y;
const int head_id = blockIdx.z;
// 16 byte loads will handle "x" dimension
auto val_src = reinterpret_cast<const uint4*>(v_src + batch_id * head_num * size_per_head * seq_len
+ head_id * size_per_head * seq_len);
auto val_dst = reinterpret_cast<uint4*>(v_dst + batch_id * head_num * size_per_head * max_seq_len
+ head_id * size_per_head * max_seq_len);
// idx is over output dimension L * size_per_head / x for values
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8;
const int size_per_head_div_x = size_per_head / X_ELEMS;
if (idx >= size_per_head_div_x * seq_len) {
return;
}
val_dst[idx] = val_src[idx];
}
template<typename T>
void invokeTranspose4dBatchMajor(T* k_dst,
T* v_dst,
const T* k_src,
const T* v_src,
const int local_batch_size,
const int seq_len,
const int max_seq_len,
const int size_per_head,
const int local_head_num,
cudaStream_t stream)
{
constexpr int block_sz = 128;
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
int size = max_seq_len * size_per_head / x;
dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num);
dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num);
transpose_4d_batch_major_k_cache<<<grid, block_sz, 0, stream>>>(
k_dst, k_src, local_head_num, size_per_head, seq_len, max_seq_len);
transpose_4d_batch_major_v_cache<<<grid_v, block_sz, 0, stream>>>(
v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len);
}
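// Cache layout note (as implied by the indexing above): the K cache is stored as
// [batch, head, size_per_head / x, max_seq_len, x] with x = 16 bytes / sizeof(T)
// (4 floats or 8 halves), so 16-byte chunks can be read along the sequence dimension;
// the V cache keeps the plain [batch, head, max_seq_len, size_per_head] layout and only
// the first seq_len rows are copied.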
#define INSTANTIATETRANSPOSE4DBATCHMAJOR(T) \
template void invokeTranspose4dBatchMajor(T* k_dst, \
T* v_dst, \
const T* k_src, \
const T* v_src, \
const int local_batch_size, \
const int seq_len, \
const int max_seq_len, \
const int size_per_head, \
const int local_head_num, \
cudaStream_t stream)
INSTANTIATETRANSPOSE4DBATCHMAJOR(float);
INSTANTIATETRANSPOSE4DBATCHMAJOR(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSE4DBATCHMAJOR(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSE4DBATCHMAJOR
template<typename T>
__global__ void addRelativeAttentionBias(
T* qk_buf, const T* relative_attention_bias, const int batch_size, const int head_num, const int seq_len)
{
for (int i = threadIdx.x; i < batch_size * seq_len; i += blockDim.x) {
int batch_id = i / seq_len;
int seq_id = i % seq_len;
const int bias_index = blockIdx.x * seq_len + seq_id;
const int qk_index = batch_id * gridDim.x * seq_len + bias_index;
qk_buf[qk_index] = add(qk_buf[qk_index], relative_attention_bias[bias_index]);
}
}
template<typename T>
void invokeAddRelativeAttentionBias(T* qk_buf,
const T* relative_attention_bias,
const int batch_size,
const int head_num,
const int seq_len,
cudaStream_t stream)
{
// qk_buf: [batch_size, head_num, seq_len, seq_len]
// relative_attention_bias: [1, head_num, seq_len, seq_len]
dim3 grid(head_num * seq_len);
dim3 block(512);
using T2 = typename TypeConverter<T>::Type;
#ifdef ENABLE_BF16
const bool is_half2 = (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) && (seq_len % 2 == 0);
#else
const bool is_half2 = (std::is_same<T, half>::value) && (seq_len % 2 == 0);
#endif
if (is_half2) {
addRelativeAttentionBias<T2><<<grid, block, 0, stream>>>(
(T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2);
}
else {
addRelativeAttentionBias<<<grid, block, 0, stream>>>(
qk_buf, relative_attention_bias, batch_size, head_num, seq_len);
}
}
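// Note: qk_buf is indexed here as [batch, head_num * seq_len, seq_len] (gridDim.x =
// head_num * seq_len), so the same [head_num, seq_len, seq_len] bias is broadcast over the
// batch dimension. The 16-bit path processes two columns per element, which is why seq_len
// is halved in that launch (and requires seq_len to be even).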
#define INSTANTIATEADDRELATIVEATTENTIONBIAS(T) \
template void invokeAddRelativeAttentionBias(T* qk_buf, \
const T* relative_attention_bias, \
const int batch_size, \
const int head_num, \
const int seq_len, \
cudaStream_t stream)
INSTANTIATEADDRELATIVEATTENTIONBIAS(float);
INSTANTIATEADDRELATIVEATTENTIONBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDRELATIVEATTENTIONBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDRELATIVEATTENTIONBIAS
/******************* invokeAddHead3SizeQKVBias ***********************/
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ is [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, 3*batch);
// block(num_head * size_per_head)
template<typename T>
__global__ void add_head3Size_QKV_bias(const T* mm_qkv,
const T* bias_qkv,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
T* buf_ptr;
int qkv_id = blockIdx.z / batch;
if (qkv_id == 0) {
buf_ptr = q_buf_;
}
else if (qkv_id == 1) {
buf_ptr = k_buf_;
}
else {
buf_ptr = v_buf_;
}
const int batch_id = blockIdx.z % batch;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
const T bias = ldg(bias_qkv + bias_idx);
const int input_idx =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
T tmp = mm_qkv[input_idx] + bias;
int target_id = (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
buf_ptr[target_id] = tmp;
}
// for float2, size_per_head /= 2
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ is [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, 3*batch);
// block(num_head * size_per_head)
template<>
__global__ void add_head3Size_QKV_bias(const float2* mm_qkv,
const float2* bias_qkv,
float2* q_buf_,
float2* k_buf_,
float2* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
float2* buf_ptr;
int qkv_id = blockIdx.z / batch;
if (qkv_id == 0) {
buf_ptr = q_buf_;
}
else if (qkv_id == 1) {
buf_ptr = k_buf_;
}
else {
buf_ptr = v_buf_;
}
const int batch_id = blockIdx.z % batch;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
const float2 bias = ldg(bias_qkv + bias_idx);
const int input_idx =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
float2 tmp = mm_qkv[input_idx];
tmp.x += bias.x;
tmp.y += bias.y;
int target_id = (((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
buf_ptr[target_id] = tmp;
}
// for half2, size_per_head /= 2
// m = batch*window_num*window_len
// mm_qkv is [m, head*3*size_per_head] row-major
// bias_qkv is [head*3*size_per_head]
// q_buf_, k_buf_, v_buf_ is [batch*window_num, num_head, window_len, size_per_head] row-major
// grid(window_len, window_num, batch);
// block(num_head * size_per_head)
template<>
__global__ void add_head3Size_QKV_bias(const half2* mm_qkv,
const half2* bias_qkv,
half2* q_buf_,
half2* k_buf_,
half2* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
const int batch_id = blockIdx.z;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int input_offset =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head;
const int target_id =
(((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
int qkv_id = 0;
int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
half2 bias = __ldg(bias_qkv + bias_idx);
int input_idx = input_offset + bias_idx;
half2 tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
q_buf_[target_id] = tmp;
qkv_id = 1;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = __ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
k_buf_[target_id] = tmp;
qkv_id = 2;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = __ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = __hadd2(tmp, bias);
v_buf_[target_id] = tmp;
}
#ifdef ENABLE_BF16
template<>
__global__ void add_head3Size_QKV_bias(const __nv_bfloat162* mm_qkv,
const __nv_bfloat162* bias_qkv,
__nv_bfloat162* q_buf_,
__nv_bfloat162* k_buf_,
__nv_bfloat162* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head)
{
const int batch_id = blockIdx.z;
const int token_id = blockIdx.x;
const int window_id = blockIdx.y;
const int head_id = threadIdx.x / size_per_head;
const int id_in_head = threadIdx.x % size_per_head;
const int input_offset =
((batch_id * window_num + window_id) * window_len + token_id) * num_head * 3 * size_per_head;
const int target_id =
(((batch_id * window_num + window_id) * num_head + head_id) * window_len + token_id) * size_per_head
+ id_in_head;
int qkv_id = 0;
int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
__nv_bfloat162 bias = ldg(bias_qkv + bias_idx);
int input_idx = input_offset + bias_idx;
__nv_bfloat162 tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
q_buf_[target_id] = tmp;
qkv_id = 1;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
k_buf_[target_id] = tmp;
qkv_id = 2;
bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
bias = ldg(bias_qkv + bias_idx);
input_idx = input_offset + bias_idx;
tmp = mm_qkv[input_idx];
tmp = bf16hadd2(tmp, bias);
v_buf_[target_id] = tmp;
}
#endif
template<typename T>
void invokeAddHead3SizeQKVBias(const T* mm_qkv,
const T* bias_qkv,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int num_head,
const int size_per_head,
cudaStream_t stream)
{
if (std::is_same<T, float>::value) {
dim3 grid(window_len, window_num, 3 * batch);
dim3 block(num_head * size_per_head);
if (block.x < 1024) {
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>(
mm_qkv, bias_qkv, q_buf_, k_buf_, v_buf_, batch, window_num, window_len, num_head, size_per_head);
}
else if ((block.x % 2 == 0) && (block.x / 2 < 1024)) {
block.x /= 2;
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const float2*)mm_qkv,
(const float2*)bias_qkv,
(float2*)q_buf_,
(float2*)k_buf_,
(float2*)v_buf_,
batch,
window_num,
window_len,
num_head,
size_per_head / 2);
}
else {
printf("[ERROR][invokeAddHead3SizeQKVBias] unsupported block.x!\n");
exit(-1);
}
}
#ifdef ENABLE_BF16
else if (std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value) {
#else
else if (std::is_same<T, half>::value) {
#endif
dim3 grid(window_len, window_num, batch);
dim3 block(num_head * size_per_head / 2);
        using T2 = typename TypeConverter<T>::Type;  // half2 or __nv_bfloat162
if (block.x > 1024) {
printf("[ERROR][invokeAddHead3SizeQKVBias] block.x > 1024!\n");
exit(-1);
}
add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const T2*)mm_qkv,
(const T2*)bias_qkv,
(T2*)q_buf_,
(T2*)k_buf_,
(T2*)v_buf_,
batch,
window_num,
window_len,
num_head,
size_per_head / 2);
}
}
#define INSTANTIATEADDHEAD3SIZEQKVBIAS(T) \
template void invokeAddHead3SizeQKVBias<T>(const T* mm_qkv, \
const T* bias_qkv, \
T* q_buf_, \
T* k_buf_, \
T* v_buf_, \
const int batch, \
const int window_num, \
const int window_len, \
const int num_head, \
const int size_per_head, \
cudaStream_t stream)
INSTANTIATEADDHEAD3SIZEQKVBIAS(float);
INSTANTIATEADDHEAD3SIZEQKVBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDHEAD3SIZEQKVBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDHEAD3SIZEQKVBIAS
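// Usage sketch for invokeAddHead3SizeQKVBias (the dimensions below are hypothetical
// Swin-style values, not taken from this file). Each thread block covers the
// num_head * size_per_head elements of one token, so that product must satisfy the
// block-size checks in the launcher above.
//
//   // mm_qkv:   fused QKV GEMM output, [batch * window_num * window_len, num_head * 3 * size_per_head]
//   // bias_qkv: fused QKV bias,        [num_head * 3 * size_per_head]
//   // q_buf/k_buf/v_buf: outputs,      [batch * window_num, num_head, window_len, size_per_head]
//   invokeAddHead3SizeQKVBias(mm_qkv, bias_qkv, q_buf, k_buf, v_buf,
//                             /*batch=*/8, /*window_num=*/64, /*window_len=*/49,
//                             /*num_head=*/4, /*size_per_head=*/32, stream);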
/******************* invokeMaskedSoftMaxWithRelPosBias ***********************/
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T>
__global__ void softmax_withRelPosBias_element1_kernel(T* qk_buf,
const T* attn_mask,
const T* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
bool qual = threadIdx.x < window_len;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int qk_offset;
if (qual) {
const int offset_in_window = window_id * window_len + threadIdx.x;
qk_offset = (blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + offset_in_window;
const int relative_pos_bias_offset = (blockIdx.y % num_head) * window_len_x_window_len + offset_in_window;
float mask_val =
(attn_mask == nullptr) ?
0.0f :
static_cast<float>(
ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window)));
tmp = qk_scale * static_cast<float>(qk_buf[qk_offset]) + mask_val
+ static_cast<float>(ldg(relative_pos_bias + relative_pos_bias_offset));
}
float max_val = blockReduceMax<float>(tmp);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float qk_tmp = qual ? __expf(tmp - s_max) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
qk_buf[qk_offset] = (T)(qk_tmp * s_mean);
}
}
}
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len/2 + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T2, typename T>
__global__ void softmax_withRelPosBias_element2_kernel(T2* qk_buf,
const T2* attn_mask,
const T2* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
const int window_len_2 = window_len / 2;
const int tidx = threadIdx.x;
bool qual = tidx < window_len_2;
const T2 zero = {T(0.0f), T(0.0f)};
const int bdim = blockDim.x;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int qk_offset;
float2 local_qk_val;
T2 qk_val;
if (qual) {
const int offset_in_window = window_id * window_len + 2 * tidx;
qk_offset = ((blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + offset_in_window) / 2;
const int relative_pos_bias_offset =
((blockIdx.y % num_head) * window_len_x_window_len + offset_in_window) / 2;
T2 mask_val =
(attn_mask == nullptr) ?
zero :
ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window) / 2);
qk_val = qk_buf[qk_offset];
local_qk_val.x = static_cast<float>(qk_val.x);
local_qk_val.y = static_cast<float>(qk_val.y);
const T2 bias_val = ldg(relative_pos_bias + relative_pos_bias_offset);
local_qk_val.x =
qk_scale * local_qk_val.x + static_cast<float>(mask_val.x) + static_cast<float>(bias_val.x);
local_qk_val.y =
qk_scale * local_qk_val.y + static_cast<float>(mask_val.y) + static_cast<float>(bias_val.y);
tmp = local_qk_val.x > local_qk_val.y ? local_qk_val.x : local_qk_val.y;
}
float max_val = bdim <= 32 ? warpReduceMax<float>(tmp) : blockReduceMax<float>(tmp);
if (tidx == 0) {
s_max = max_val;
}
__syncthreads();
local_qk_val.x = qual ? __expf(local_qk_val.x - s_max) : 0.0f;
local_qk_val.y = qual ? __expf(local_qk_val.y - s_max) : 0.0f;
float sum_val = bdim <= 32 ? warpReduceSum<float>(local_qk_val.x + local_qk_val.y) :
blockReduceSum<float>(local_qk_val.x + local_qk_val.y);
if (tidx == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
local_qk_val.x = local_qk_val.x * s_mean;
local_qk_val.y = local_qk_val.y * s_mean;
qk_val.x = T(local_qk_val.x);
qk_val.y = T(local_qk_val.y);
qk_buf[qk_offset] = qk_val;
}
}
}
// grid = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len/4 + 31)/32*32)
// qk_buf is [batch, window_num, num_head, window_len, window_len]
// attn_mask is [window_num, window_len, window_len] + row-major
// relative_pos_bias is [num_head, window_len, window_len] + row-major
template<typename T4, typename T>
__global__ void softmax_withRelPosBias_element4_kernel(T4* qk_buf,
const T4* attn_mask,
const T4* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const int window_len_x_window_len,
const float qk_scale)
{
const int window_len_4 = window_len / 4;
const int tidx = threadIdx.x;
bool qual = tidx < window_len_4;
const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)};
const int bdim = blockDim.x;
for (int window_id = blockIdx.x; window_id < window_len; window_id += gridDim.x) {
float tmp = -1e20f;
__shared__ float s_mean, s_max;
int qk_offset;
float4 local_qk_val;
T4 qk_val;
if (qual) {
const int offset_in_window = window_id * window_len + 4 * tidx;
qk_offset = ((blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + offset_in_window) / 4;
const int relative_pos_bias_offset =
((blockIdx.y % num_head) * window_len_x_window_len + offset_in_window) / 4;
T4 mask_val = (attn_mask == nullptr) ?
zero :
attn_mask[((blockIdx.y / num_head) * window_len_x_window_len + offset_in_window) / 4];
qk_val = qk_buf[qk_offset];
local_qk_val.x = static_cast<float>(qk_val.x);
local_qk_val.y = static_cast<float>(qk_val.y);
local_qk_val.z = static_cast<float>(qk_val.z);
local_qk_val.w = static_cast<float>(qk_val.w);
const T4 bias_val = relative_pos_bias[relative_pos_bias_offset];
local_qk_val.x =
qk_scale * local_qk_val.x + static_cast<float>(mask_val.x) + static_cast<float>(bias_val.x);
local_qk_val.y =
qk_scale * local_qk_val.y + static_cast<float>(mask_val.y) + static_cast<float>(bias_val.y);
local_qk_val.z =
qk_scale * local_qk_val.z + static_cast<float>(mask_val.z) + static_cast<float>(bias_val.z);
local_qk_val.w =
qk_scale * local_qk_val.w + static_cast<float>(mask_val.w) + static_cast<float>(bias_val.w);
tmp = local_qk_val.x > local_qk_val.y ? local_qk_val.x : local_qk_val.y;
tmp = tmp > local_qk_val.z ? tmp : local_qk_val.z;
tmp = tmp > local_qk_val.w ? tmp : local_qk_val.w;
}
float max_val = bdim <= 32 ? warpReduceMax<float>(tmp) : blockReduceMax<float>(tmp);
if (tidx == 0) {
s_max = max_val;
}
__syncthreads();
local_qk_val.x = qual ? __expf(local_qk_val.x - s_max) : 0.0f;
local_qk_val.y = qual ? __expf(local_qk_val.y - s_max) : 0.0f;
local_qk_val.z = qual ? __expf(local_qk_val.z - s_max) : 0.0f;
local_qk_val.w = qual ? __expf(local_qk_val.w - s_max) : 0.0f;
float sum_val = bdim <= 32 ?
warpReduceSum<float>(local_qk_val.x + local_qk_val.y + local_qk_val.z + local_qk_val.w) :
blockReduceSum<float>(local_qk_val.x + local_qk_val.y + local_qk_val.z + local_qk_val.w);
if (tidx == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
if (qual) {
local_qk_val.x = local_qk_val.x * s_mean;
local_qk_val.y = local_qk_val.y * s_mean;
local_qk_val.z = local_qk_val.z * s_mean;
local_qk_val.w = local_qk_val.w * s_mean;
qk_val.x = T(local_qk_val.x);
qk_val.y = T(local_qk_val.y);
qk_val.z = T(local_qk_val.z);
qk_val.w = T(local_qk_val.w);
qk_buf[qk_offset] = qk_val;
}
}
}
template<typename T>
void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf,
const T* attn_mask,
const T* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
float qk_scale,
cudaStream_t stream)
{
const int word_per_thread = 1;
dim3 grid((window_len + word_per_thread - 1) / word_per_thread, window_num * num_head, batch_size);
if ((window_len % 4 == 0) && window_len / 4 >= 32) {
dim3 block((window_len / 4 + 31) / 32 * 32);
if (std::is_same<T, float>::value) {
softmax_withRelPosBias_element4_kernel<float4, float>
<<<grid, block, 0, stream>>>((float4*)qk_buf,
(const float4*)attn_mask,
(const float4*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
else if (std::is_same<T, half>::value) {
softmax_withRelPosBias_element4_kernel<half4, half>
<<<grid, block, 0, stream>>>((half4*)qk_buf,
(const half4*)attn_mask,
(const half4*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#ifdef ENABLE_BF16
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#endif
}
else if (window_len % 2 == 0) {
dim3 block((window_len / 2 + 31) / 32 * 32);
if (std::is_same<T, float>::value) {
softmax_withRelPosBias_element2_kernel<float2, float>
<<<grid, block, 0, stream>>>((float2*)qk_buf,
(const float2*)attn_mask,
(const float2*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
else if (std::is_same<T, half>::value) {
softmax_withRelPosBias_element2_kernel<half2, half>
<<<grid, block, 0, stream>>>((half2*)qk_buf,
(const half2*)attn_mask,
(const half2*)relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#ifdef ENABLE_BF16
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
#endif
}
else {
dim3 block((window_len + 31) / 32 * 32);
softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
attn_mask,
relative_pos_bias,
batch_size,
num_head,
window_num,
window_len,
window_len * window_len,
qk_scale);
}
}
#define INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(T) \
template void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, \
const T* attn_mask, \
const T* relative_pos_bias, \
const int batch_size, \
const int num_head, \
const int window_num, \
const int window_len, \
const float qk_scale, \
cudaStream_t stream)
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(float);
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS
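// Usage sketch for invokeMaskedSoftMaxWithRelPosBias (dimensions and the qk_scale choice
// below are hypothetical illustrations, not taken from this file). The dispatcher above
// picks the 4-, 2- or 1-element kernel based on window_len (and the data type), so callers
// only pass the logical shapes.
//
//   // qk_buf:            [batch_size, window_num, num_head, window_len, window_len], in/out
//   // attn_mask:         [window_num, window_len, window_len], may be nullptr
//   // relative_pos_bias: [num_head, window_len, window_len]
//   const float qk_scale = 1.0f / sqrtf((float)size_per_head);  // a common scaling choice
//   invokeMaskedSoftMaxWithRelPosBias(qk_buf, attn_mask, relative_pos_bias,
//                                     /*batch_size=*/8, /*num_head=*/4, /*window_num=*/64,
//                                     /*window_len=*/49, qk_scale, stream);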
template<typename T>
__global__ void transpose_attentions(
T* attentions_out, const T* attentions_in, size_t batch_size, size_t num_layers, size_t num_heads, size_t seq_len)
{
// attentions_in shape [B, H, S, S]
// attentions_out shape [B, L, H, S, S].
// Note that we write the L dimension as if it was index 0.
// In reality, the pointer has already been shifted to point to the correct layer.
const auto batch_idx = blockIdx.x;
const auto head_idx = blockIdx.y;
const auto dst_offset = (batch_idx * num_layers * num_heads + head_idx) * seq_len * seq_len;
const auto src_offset = (batch_idx * num_heads + head_idx) * seq_len * seq_len;
for (auto x = threadIdx.x; x < seq_len * seq_len; x += blockDim.x) {
attentions_out[dst_offset + x] = attentions_in[src_offset + x];
}
}
template<typename T>
void invokeTransposeAttentions(Tensor& attentions_out, const Tensor& attentions_in, cudaStream_t stream)
{
const size_t batch_size = attentions_in.shape[0];
const size_t num_heads = attentions_in.shape[1];
const size_t seq_len = attentions_in.shape[2];
const size_t num_layers = attentions_out.shape[1];
const dim3 gridSize(batch_size, num_heads);
const dim3 blockSize(512);
transpose_attentions<<<gridSize, blockSize, 0, stream>>>(
attentions_out.getPtr<T>(), attentions_in.getPtr<const T>(), batch_size, num_layers, num_heads, seq_len);
}
#define INSTANTIATETRANSPOSEATTENTIONS(T) \
template void invokeTransposeAttentions<T>( \
Tensor & attentions_out, const Tensor& attentions_in, cudaStream_t stream)
INSTANTIATETRANSPOSEATTENTIONS(float);
INSTANTIATETRANSPOSEATTENTIONS(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEATTENTIONS(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEATTENTIONS
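// Worked offset example for transpose_attentions (hypothetical sizes, for illustration only):
// with num_layers = 2, num_heads = 4, seq_len = 8, the map of (batch_idx = 1, head_idx = 3) is
// read from src_offset = (1 * 4 + 3) * 8 * 8 = 448 in the [B, H, S, S] input and written to
// dst_offset = (1 * 2 * 4 + 3) * 8 * 8 = 704 in the [B, L, H, S, S] output, relative to the
// caller-provided pointer that has already been shifted to the target layer.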
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/utils/Tensor.h"
namespace fastertransformer {
template<typename T>
void invokeAddQKVBiasIA3Transpose(T* q_buf,
T* k_buf,
T* v_buf,
T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T, typename T_IN>
struct MaskedSoftmaxParam {
// Common parameters.
T* attention_score = nullptr; // (batch_size, head_num, q_length, k_length)
const T_IN* qk = nullptr; // (batch_size, head_num, q_length, k_length)
const T* attention_mask = nullptr; // (batch_size, q_length, k_length)
int batch_size = 0;
int q_length = 0;
int k_length = 0;
int num_heads = 0;
T qk_scale = T(0.0f);
// Optional parameters that depend on the type of attention.
// The slopes of the linear position bias of ALiBi.
const T* linear_bias_slopes = nullptr; // (head_num,), optional
};
template<typename T, typename T_IN>
void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream);
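// Minimal usage sketch for invokeMaskedSoftmax (variable names and the qk_scale choice are
// hypothetical; only the struct fields come from this header):
//
//   MaskedSoftmaxParam<half, float> param;
//   param.attention_score = attn_score;  // (batch_size, head_num, q_length, k_length)
//   param.qk              = qk_buf;      // (batch_size, head_num, q_length, k_length)
//   param.attention_mask  = attn_mask;   // (batch_size, q_length, k_length)
//   param.batch_size      = batch_size;
//   param.q_length        = q_length;
//   param.k_length        = k_length;
//   param.num_heads       = head_num;
//   param.qk_scale        = half(1.0f / sqrtf((float)size_per_head));  // a common choice
//   // param.linear_bias_slopes stays nullptr unless ALiBi is used
//   invokeMaskedSoftmax(param, stream);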
template<typename T>
void invokeTransposeQKV(T* dst,
T* src,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const float* scale,
const int int8_mode,
cudaStream_t stream);
template<typename T>
void invokeAddQKVBiasIA3RebuildPadding(T* Q,
const T* bias_Q,
T* K,
const T* bias_K,
T* V,
const T* bias_V,
T* q_buf,
T* k_buf,
T* v_buf,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int valid_word_num,
const int* mask_offset,
const int* ia3_tasks,
const T* ia3_key_weights,
const T* ia3_value_weights,
cudaStream_t stream);
template<typename T>
void invokeTransposeAttentionOutRemovePadding(T* src,
T* dst,
const int valid_word_num,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head,
const int* mask_offset,
const float* scale,
const int int8_mode,
cudaStream_t stream);
// Prefix Prompt Parameters
template<typename T>
struct PrefixPromptBatchWeightsParam {
const T** d_prefix_prompt_batch = nullptr;
const int* d_prefix_prompt_lengths = nullptr;
const int max_prefix_prompt_length = 0;
// l * 2 * hidden_units_ / tensor_para_.world_size_
const size_t prefix_prompt_layer_offset_per_seq = 0;
};
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int* history_length,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream);
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
PrefixPromptBatchWeightsParam<T> param,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
const float* scale,
const int int8_mode,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
param,
QKV,
qkv_bias,
padding_offset,
nullptr,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
rotary_embedding_dim,
neox_rotary_style,
scale,
int8_mode,
stream);
}
template<typename T>
void invokeAddFusedQKVBiasTranspose(T* q_buf,
T* k_buf,
T* v_buf,
T* QKV,
const T* qkv_bias,
const int* padding_offset,
const int batch_size,
const int seq_len,
const int token_num,
const int head_num,
const int size_per_head,
cudaStream_t stream)
{
invokeAddFusedQKVBiasTranspose(q_buf,
k_buf,
v_buf,
PrefixPromptBatchWeightsParam<T>{},
QKV,
qkv_bias,
padding_offset,
batch_size,
seq_len,
token_num,
head_num,
size_per_head,
0,
false,
(float*)nullptr,
0,
stream);
}
template<typename T>
void invokeTranspose4d(T* dst,
T* src,
const int local_batch_size,
const int seq_len,
const int size_per_head,
const int local_hidden_units,
const int local_head_num,
const int batch_size,
const int ite,
cudaStream_t stream);
template<typename T>
void invokeTranspose4dBatchMajor(T* k_dst,
T* v_dst,
const T* k_src,
const T* v_src,
const int local_batch_size,
const int seq_len,
const int max_seq_len,
const int size_per_head,
const int local_head_num,
cudaStream_t stream);
template<typename T>
void invokeAddRelativeAttentionBias(T* qk_buf,
const T* relative_attention_bias,
const int batch_size,
const int head_num,
const int seq_len,
cudaStream_t stream);
template<typename T>
void invokeAddHead3SizeQKVBias(const T* mm_qkv,
const T* bias_qkv,
T* q_buf_,
T* k_buf_,
T* v_buf_,
const int batch,
const int window_num,
const int window_len,
const int head_num,
const int size_per_head,
cudaStream_t stream);
template<typename T>
void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf,
const T* attn_mask,
const T* relative_pos_bias,
const int batch_size,
const int num_head,
const int window_num,
const int window_len,
const float qk_scale,
cudaStream_t stream);
template<typename T>
void invokeTransposeAttentions(Tensor& attentions_out, const Tensor& attentions_in, cudaStream_t stream = 0);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include "src/fastertransformer/utils/Tensor.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/utils/cublasMMWrapper.h"
namespace fastertransformer {
class BaseLayer {
public:
BaseLayer(cudaStream_t stream,
cublasMMWrapper* cublas_wrapper,
IAllocator* allocator,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr,
bool sparse = false):
stream_(stream),
cublas_wrapper_(cublas_wrapper),
allocator_(allocator),
cuda_device_prop_(cuda_device_prop),
is_free_buffer_after_forward_(is_free_buffer_after_forward),
        sparse_(sparse) {}
virtual ~BaseLayer() = default;
virtual cudaStream_t getStream()
{
return stream_;
}
virtual void setStream(cudaStream_t stream)
{
stream_ = stream;
}
protected:
virtual void allocateBuffer() = 0;
virtual void freeBuffer() = 0;
// device environments
cudaStream_t stream_;
cublasMMWrapper* cublas_wrapper_;
IAllocator* allocator_;
cudaDeviceProp* cuda_device_prop_ = nullptr;
bool is_free_buffer_after_forward_;
bool is_allocate_buffer_ = false; // TODO (bhsueh) to be deprecated
bool sparse_;
};
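// Sketch of a derived layer (the class name is hypothetical; only the BaseLayer contract is
// taken from this header): subclasses inherit the device handles and implement the two
// buffer hooks.
//
//   class MyLayer: public BaseLayer {
//   public:
//       using BaseLayer::BaseLayer;  // reuse the stream / cublas / allocator constructor
//   protected:
//       void allocateBuffer() override {}  // grab workspace via allocator_ when needed
//       void freeBuffer() override {}      // release it, e.g. when is_free_buffer_after_forward_
//   };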
} // namespace fastertransformer