Unverified Commit 9efcac38 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

check-in fastertransformer (#7)

* add ft code

* gitignore

* fix lint

* revert fmha
parent 720fc533
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/memory_utils.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif
namespace fastertransformer {
/* Gelu Activation */
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__inline__ __device__ float tanh_opt(float x)
{
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
float r;
asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x));
return r;
#else
const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#endif
}
template<typename T>
struct GeluActivation {
using return_type = T;
static __device__ __forceinline__ T apply(const T& val)
{
const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val))));
return val * cdf;
}
};
template<>
struct GeluActivation<half2> {
using return_type = half2;
static __device__ __forceinline__ half2 apply(const half2& val)
{
half2 val_pow3 = __hmul2(val, __hmul2(val, val));
float2 tmp_pow = __half22float2(val_pow3);
float2 tmp = __half22float2(val);
tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
return __hmul2(val, __float22half2_rn(tmp));
}
};
#ifdef ENABLE_BF16
template<>
struct GeluActivation<__nv_bfloat162> {
using return_type = __nv_bfloat162;
static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
{
__nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val));
float2 tmp_pow = bf1622float2(val_pow3);
float2 tmp = bf1622float2(val);
tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y));
}
};
#endif
/* Relu Activation */
template<typename T>
struct ReluActivation {
using return_type = T;
static __device__ __forceinline__ T apply(const T& val)
{
return val > static_cast<T>(0.0f) ? val : static_cast<T>(0.0f);
}
};
template<>
struct ReluActivation<half2> {
using return_type = half2;
static __device__ __forceinline__ half2 apply(const half2& val)
{
const half zero_half = static_cast<half>(0.0f);
return make_half2(val.x > zero_half ? val.x : zero_half, val.y > zero_half ? val.y : zero_half);
}
};
#ifdef ENABLE_BF16
template<>
struct ReluActivation<__nv_bfloat162> {
using return_type = __nv_bfloat162;
static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
{
const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f);
return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? val.y : zero_bf16);
}
};
#endif
/* Silu Activation */
template<typename T>
struct SiluActivation {
using return_type = T;
static __device__ __forceinline__ T apply(const T& val)
{
return (T)((float)val / (1.0f + __expf((float)-val)));
}
};
template<>
struct SiluActivation<half2> {
using return_type = float2;
static __device__ __forceinline__ float2 apply(const half2& val)
{
return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));
}
};
#ifdef ENABLE_BF16
template<>
struct SiluActivation<__nv_bfloat162> {
using return_type = float2;
static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val)
{
return make_float2(SiluActivation<float>::apply(val.x), SiluActivation<float>::apply(val.y));
}
};
#endif // ENABLE_BF16
/* Identity Activation (= no activation) */
template<typename T>
struct IdentityActivation {
using return_type = T;
static __device__ __forceinline__ T apply(const T& val)
{
return val;
}
};
// clang-format off
template<template<typename T> class Activation, typename T, typename BT>
__global__ void generic_activation(T* out,
const BT* __restrict bias,
const T* __restrict gated_weights,
const BT* __restrict gated_bias,
const int* __restrict ia3_tasks,
const T* __restrict ia3_weights,
const int int8_mode,
const float* __restrict activation_in,
const float* __restrict activation_out,
const int* __restrict padding_offset,
const int seq_len,
int m,
int n)
{
constexpr size_t packed_elems = num_elems<T>::value;
const bool with_bias = bias != nullptr;
const bool with_gate = gated_weights != nullptr;
// const bool with_ia3 = ia3_tasks != nullptr;
using Act_T = typename Activation<T>::return_type;
using Float_T = typename packed_as<float, packed_elems>::type;
using Packed_Int8_t = typename packed_as<int8_t, packed_elems>::type;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
T val;
if (int8_mode == 2) {
// val = cuda_cast<T>(cuda_cast<Float_T>(reinterpret_cast<Packed_Int8_t*>(out)[id]) * activation_in[0]);
}
else {
val = out[id];
}
T gated_val;
if (with_gate) {
gated_val = gated_weights[id];
}
// if (with_bias) {
// const T reg_bias = static_cast<T>(bias[id % n]);
// val = val + reg_bias;
// if (with_gate) {
// const T reg_gated_bias = static_cast<T>(gated_bias[id % n]);
// gated_val = gated_val + reg_gated_bias;
// }
// }
if (with_gate) {
val = cuda_cast<T>(Activation<T>::apply(val) * cuda_cast<Act_T>(gated_val));
}
else {
// val = cuda_cast<T>(Activation<T>::apply(val));
}
// if (with_ia3) {
// const int word_id = id / n;
// const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
// const int batch_id = (word_id + offset) / seq_len;
// const int task = ia3_tasks[batch_id];
// val = val * ia3_weights[task * n + (id % n)];
// }
if (int8_mode != 2) {
out[id] = val;
}
else {
// reinterpret_cast<Packed_Int8_t*>(out)[id] =
// cuda_cast<Packed_Int8_t>(cuda_cast<Float_T>(val) * activation_out[0]);
}
}
}
// clang-format on
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T* out,
const BT* bias,
const T* gated_weights,
const BT* gated_bias,
const int* ia3_tasks,
const T* ia3_weights,
const int m,
const int n,
const int int8_mode,
const float* activation_in,
const float* activation_out,
const int* padding_offset,
const int seq_len,
cudaStream_t stream)
{
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
FT_LOG_DEBUG("invokeGenericActivation %d %d %d", m, n, seq_len);
using PT = typename packed_type<T>::type;
constexpr int packed_elems = num_elems<PT>::value;
using PBT = typename packed_as<BT, packed_elems>::type;
const int n_threads = 512;
dim3 block, grid;
if (n / 4 / packed_elems <= n_threads) {
block.x = n / 4 / packed_elems;
grid.x = m;
}
else {
block.x = n_threads;
grid.x = ceil(m * n / double(n_threads));
}
FT_LOG_DEBUG("%d %d", grid.x, block.x);
sync_check_cuda_error();
generic_activation<Activation><<<grid, block, 0, stream>>>(reinterpret_cast<PT*>(out),
reinterpret_cast<const PBT*>(bias),
reinterpret_cast<const PT*>(gated_weights),
reinterpret_cast<const PBT*>(gated_bias),
ia3_tasks,
reinterpret_cast<const PT*>(ia3_weights),
int8_mode,
activation_in,
activation_out,
padding_offset,
seq_len,
m,
n / packed_elems);
sync_check_cuda_error();
}
#define INSTANTIATE_GENERIC_ACTIVATION(Activation, T, BT) \
template void invokeGenericActivation<Activation, T, BT>(T * out, \
const BT* bias, \
const T* gated_weights, \
const BT* gated_bias, \
const int* ia3_tasks, \
const T* ia3_weights, \
const int m, \
const int n, \
const int int8_mode, \
const float* activation_in, \
const float* activation_out, \
const int* padding_offset, \
const int seq_len, \
cudaStream_t stream);
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16);
#endif
INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16);
#endif
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, half, half);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, __nv_bfloat16, __nv_bfloat16);
INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, __nv_bfloat16);
#endif
#undef INSTANCIATE_GENERIC_ACTIVATION
template<typename T>
__global__ void add_bias_tanh(T* out, const T* __restrict bias, int m, int n)
{
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
T val = out[id];
if (bias != nullptr) {
val = val + ldg(&bias[id % n]);
}
out[id] = tanhf(val);
}
}
template<>
__global__ void add_bias_tanh(half* out, const half* __restrict bias, int m, int n)
{
half2* out_ptr = (half2*)out;
const half2* bias_ptr = (half2*)bias;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
half2 val = out_ptr[id];
if (bias != nullptr) {
val = val + __ldg(&bias_ptr[id % n]);
}
val.x = tanhf(val.x);
val.y = tanhf(val.y);
out_ptr[id] = val;
}
}
#ifdef ENABLE_BF16
template<>
__global__ void add_bias_tanh(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n)
{
__nv_bfloat162* out_ptr = (__nv_bfloat162*)out;
const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
__nv_bfloat162 val = out_ptr[id];
if (bias != nullptr) {
val = bf16hadd2(val, ldg(&bias_ptr[id % n]));
}
val.x = tanhf(val.x);
val.y = tanhf(val.y);
out_ptr[id] = val;
}
}
#endif
template<typename T>
void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream)
{
const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16
dim3 block, grid;
if (n / 4 / data_type_factor <= 1024) {
block.x = n / 4 / data_type_factor;
grid.x = m;
}
else {
block.x = 1024;
grid.x = ceil(m * n / 1024.);
}
add_bias_tanh<T><<<grid, block, 0, stream>>>(out, bias, m, n / data_type_factor);
}
template void invokeAddBiasTanh(float* out, const float* bias, const int m, const int n, cudaStream_t stream);
template void invokeAddBiasTanh(half* out, const half* bias, const int m, const int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void
invokeAddBiasTanh(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream);
#endif
template<typename T2, int N>
__global__ void addBiasGeluV2(T2* out,
const T2* __restrict bias,
const int* ia3_tasks,
const T2* ia3_weights,
const int size,
const int* padding_offset,
const int seq_len)
{
const bool with_ia3 = ia3_tasks != nullptr;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
T2 val = out[id];
if (bias != nullptr) {
T2 reg_bias = ldg(&bias[id % N]);
val = hadd2(val, reg_bias);
}
val = GeluActivation<T2>::apply(val);
if (with_ia3) {
const int word_id = id / N;
const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
const int batch_id = (word_id + offset) / seq_len;
const int task = ia3_tasks[batch_id];
val = val * ia3_weights[task * N + (id % N)];
}
out[id] = val;
}
}
template<typename T2, int N, int ELEMENT_PER_ROUND>
__global__ void addBiasGeluV3(T2* out,
const T2* __restrict bias,
const int* ia3_tasks,
const T2* ia3_weights,
const int size,
const int* padding_offset,
const int seq_len)
{
const bool with_ia3 = ia3_tasks != nullptr;
T2 buffer[ELEMENT_PER_ROUND];
T2 tmp_bias[ELEMENT_PER_ROUND];
for (int id = blockIdx.x * blockDim.x * ELEMENT_PER_ROUND + threadIdx.x * ELEMENT_PER_ROUND; id < size;
id += blockDim.x * gridDim.x * ELEMENT_PER_ROUND) {
#pragma unroll
for (int i = 0; i < ELEMENT_PER_ROUND; i++) {
buffer[i] = out[id + i];
if (bias != nullptr) {
tmp_bias[i] = ldg(&bias[(id + i) % N]);
}
}
#pragma unroll
for (int i = 0; i < ELEMENT_PER_ROUND; i++) {
if (bias != nullptr) {
buffer[i] = hadd2(buffer[i], tmp_bias[i]);
}
buffer[i] = GeluActivation<T2>::apply(buffer[i]);
if (with_ia3) {
const int word_id = (id + i) / N;
const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id];
const int batch_id = (word_id + offset) / seq_len;
const int task = ia3_tasks[batch_id];
buffer[i] = buffer[i] * ia3_weights[task * N + ((id + i) % N)];
}
out[id + i] = buffer[i];
}
}
}
#define ADD_BIAS_GELU(HALF_N, ELEMENT_PER_ROUND) \
case HALF_N: \
if (ELEMENT_PER_ROUND > 1) { \
grid.x = grid.x / ELEMENT_PER_ROUND; \
addBiasGeluV3<T2, HALF_N, ELEMENT_PER_ROUND><<<grid, block, 0, stream>>>( \
(T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \
} \
else { \
addBiasGeluV2<T2, HALF_N><<<grid, block, 0, stream>>>( \
(T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \
} \
break;
template<typename T>
void invokeAddBiasGeluV2(T* out,
const T* bias,
const int* ia3_tasks,
const T* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream)
{
if (n % 2 == 0 && sizeof(T) == 2) {
const int half_n = n / 2;
dim3 block, grid;
block.x = std::min(half_n, 512);
grid.x = (m * half_n + (block.x - 1)) / block.x;
using T2 = typename TypeConverter<T>::Type;
if (grid.x >= 512) {
switch (half_n) {
ADD_BIAS_GELU(256, 1)
ADD_BIAS_GELU(512, 1)
ADD_BIAS_GELU(1024, 1)
ADD_BIAS_GELU(1536, 1)
ADD_BIAS_GELU(2048, 1)
ADD_BIAS_GELU(4096, 2)
ADD_BIAS_GELU(8192, 2)
ADD_BIAS_GELU(16384, 2)
ADD_BIAS_GELU(24576, 2)
ADD_BIAS_GELU(40960, 4)
default:
invokeGenericActivation<GeluActivation>(out,
bias,
(T*)nullptr,
(T*)nullptr,
ia3_tasks,
ia3_weights,
m,
n,
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
break;
}
}
else {
switch (half_n) {
ADD_BIAS_GELU(256, 1)
ADD_BIAS_GELU(512, 1)
ADD_BIAS_GELU(1024, 1)
ADD_BIAS_GELU(1536, 1)
ADD_BIAS_GELU(2048, 1)
ADD_BIAS_GELU(4096, 1)
ADD_BIAS_GELU(8192, 2)
ADD_BIAS_GELU(16384, 2)
ADD_BIAS_GELU(24576, 2)
ADD_BIAS_GELU(40960, 2)
default:
invokeGenericActivation<GeluActivation>(out,
bias,
(T*)nullptr,
(T*)nullptr,
ia3_tasks,
ia3_weights,
m,
n,
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
break;
}
}
}
else {
invokeGenericActivation<GeluActivation>(out,
bias,
(T*)nullptr,
(T*)nullptr,
ia3_tasks,
ia3_weights,
m,
n,
0,
(float*)nullptr,
(float*)nullptr,
padding_offset,
seq_len,
stream);
}
}
#undef ADD_BIAS_GELU
template void invokeAddBiasGeluV2(float* out,
const float* bias,
const int* ia3_tasks,
const float* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
template void invokeAddBiasGeluV2(half* out,
const half* bias,
const int* ia3_tasks,
const half* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAddBiasGeluV2(__nv_bfloat16* out,
const __nv_bfloat16* bias,
const int* ia3_tasks,
const __nv_bfloat16* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
#endif // ENABLE_BF16
template<typename T>
__global__ void sigmoid_kernel(T* data, const int size, const float scale)
{
const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
if (index < size) {
float val = cuda_cast<float>(data[index]);
val = 1.0f / (1.0f + exp(-val)) * scale;
data[index] = T(val);
}
}
template<>
__global__ void sigmoid_kernel(half2* data, const int size, const float scale)
{
const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x;
if (index < size / 2) {
half2 val = data[index];
float2 val_float2 = cuda_cast<float2>(val);
val_float2.x = 1.0f / (1.0f + exp(-val_float2.x)) * scale;
val_float2.y = 1.0f / (1.0f + exp(-val_float2.y)) * scale;
data[index] = cuda_cast<half2>(val_float2);
}
}
template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream)
{
if (std::is_same<T, float>::value || (size % 2 != 0)) {
dim3 block(128);
dim3 grid((size + 127) / 128);
sigmoid_kernel<<<grid, block, 0, stream>>>(data, size, scale);
}
else {
dim3 block(128);
dim3 grid((size + 255) / 256);
sigmoid_kernel<<<grid, block, 0, stream>>>((half2*)data, size, scale);
}
}
template void invokeSigmoid(float* data, const int size, const float scale, cudaStream_t stream);
template void invokeSigmoid(half* data, const int size, const float scale, cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdlib.h>
namespace fastertransformer {
// clang-format off
template<typename T> struct GeluActivation;
template<typename T> struct ReluActivation;
template<typename T> struct SiluActivation;
template<typename T> struct IdentityActivation;
// clang-format on
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T* out,
const BT* bias,
const T* gated_weights,
const BT* gated_bias,
const int* ia3_tasks,
const T* ia3_weights,
const int m,
const int n,
const int int8_mode,
const float* activation_in,
const float* activation_out,
const int* padding_offset,
const int seq_len,
cudaStream_t stream);
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T* out,
const BT* bias,
const T* gated_weights,
const BT* gated_bias,
const int* ia3_tasks,
const T* ia3_weights,
const int m,
const int n,
const int int8_mode,
const float* activation_in,
const float* activation_out,
cudaStream_t stream)
{
invokeGenericActivation<Activation, T, BT>(out,
bias,
gated_weights,
gated_bias,
ia3_tasks,
ia3_weights,
m,
n,
int8_mode,
activation_in,
activation_out,
(const int*)nullptr,
0,
stream);
}
template<typename T>
void invokeAddBiasGeluV2(T* out,
const T* bias,
const int* ia3_tasks,
const T* ia3_weights,
const int* padding_offset,
const int seq_len,
const int m,
const int n,
cudaStream_t stream);
template<typename T>
void invokeAddBias(T* out, T const* bias, const int m, const int n, cudaStream_t stream)
{
invokeGenericActivation<IdentityActivation, T, T>(
out, bias, nullptr, nullptr, nullptr, nullptr, m, n, 0, nullptr, nullptr, stream);
}
template<typename T>
void invokeAddBiasGeluV2(
T* out, const T* bias, const int* ia3_tasks, const T* ia3_weights, const int m, const int n, cudaStream_t stream)
{
invokeAddBiasGeluV2(out, bias, ia3_tasks, ia3_weights, nullptr, 0, m, n, stream);
}
template<typename T>
void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream);
template<typename T>
void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/ban_bad_words.h"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
template<typename T>
__global__ void ban_bad_words(T* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int beam_width,
const int* bad_words,
size_t bad_words_len,
bool share_words,
int id_offset,
int vocab_size_padded,
size_t step)
{
const int id = blockIdx.x * blockDim.x + threadIdx.x;
const int batch_idx = blockIdx.y / beam_width;
const int beam_idx = blockIdx.y % beam_width;
const int* base_bad_words = share_words ? bad_words : bad_words + batch_idx * 2 * bad_words_len;
const int* base_bad_words_offsets = base_bad_words + bad_words_len;
if (id >= bad_words_len || base_bad_words_offsets[id] < 0) {
return;
}
const int item_end = base_bad_words_offsets[id];
const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
const int item_size = item_end - item_start;
/* The single-token case unconditionally bans the token */
bool should_ban = item_size == 1;
/* Multi-token case and enough previously generated tokens to look for a match */
if (item_size > 1 && step >= item_size - 1) {
should_ban = true;
int parent_id = beam_idx;
const bool gather_beam = beam_width > 1;
for (int token_idx = item_size - 2; token_idx >= 0; token_idx--) {
const int previous_token = output_ids_buf[(step - (item_size - 1) + token_idx) * batch_size * beam_width
+ id_offset + batch_idx * beam_width + parent_id];
if (previous_token != base_bad_words[item_start + token_idx]) {
should_ban = false;
break;
}
if (gather_beam) {
parent_id = parent_ids_buf[(step - (item_size - 1) + token_idx) * beam_width * batch_size + id_offset
+ batch_idx * beam_width + parent_id];
if (parent_id < 0 || parent_id >= beam_width) {
should_ban = false;
break;
}
}
}
}
if (should_ban) {
int banned_token = base_bad_words[item_end - 1];
if (0 < banned_token && banned_token < vocab_size_padded) {
logits[batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token] =
static_cast<T>(-INFINITY);
}
}
}
template<typename T>
void invokeBanBadWords(T* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int local_batch_size,
int beam_width,
const int* bad_words,
bool share_words,
size_t bad_words_len,
int id_offset,
int vocab_size_padded,
size_t step,
cudaStream_t stream)
{
dim3 block, grid;
block.x = min(((bad_words_len + 32 - 1) / 32) * 32, 256UL);
grid.x = (bad_words_len + block.x - 1) / block.x;
grid.y = local_batch_size * beam_width;
ban_bad_words<<<grid, block, 0, stream>>>(logits,
output_ids_buf,
parent_ids_buf,
batch_size,
beam_width,
bad_words,
bad_words_len,
share_words,
id_offset,
vocab_size_padded,
step);
sync_check_cuda_error();
}
template void invokeBanBadWords(half* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int local_batch_size,
int beam_width,
const int* bad_words,
bool share_words,
size_t bad_words_len,
int id_offset,
int vocab_size_padded,
size_t step,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBanBadWords(__nv_bfloat16* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int local_batch_size,
int beam_width,
const int* bad_words,
bool share_words,
size_t bad_words_len,
int id_offset,
int vocab_size_padded,
size_t step,
cudaStream_t stream);
#endif
template void invokeBanBadWords(float* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int local_batch_size,
int beam_width,
const int* bad_words,
bool share_words,
size_t bad_words_len,
int id_offset,
int vocab_size_padded,
size_t step,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace fastertransformer {
template<typename T>
void invokeBanBadWords(T* logits,
const int* output_ids_buf,
const int* parent_ids_buf,
int batch_size,
int local_batch_size,
int beam_width,
const int* bad_words,
bool share_words,
size_t bad_words_len,
int id_offset,
int vocab_size_padded,
size_t step,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include "src/fastertransformer/kernels/beam_search_penalty_kernels.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
namespace fastertransformer {
template<typename T>
__global__ void add_bias_temperature(T* logits,
const T* bias,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const float temperature)
{
int tid = threadIdx.x;
int bid = blockIdx.x;
int bbid = blockIdx.y;
logits += bbid * vocab_size_padded;
const T MASK_VAL = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
const T inv_temp = static_cast<T>(1.0f / (temperature + 1e-6f));
for (int i = tid + bid * blockDim.x; i < vocab_size_padded; i += blockDim.x * gridDim.x) {
if (i < vocab_size) {
T bias_val = bias == nullptr ? (T)(0.0f) : bias[i];
logits[i] = (logits[i] + bias_val) * inv_temp;
}
else {
logits[i] = MASK_VAL;
}
}
}
template<>
__global__ void add_bias_temperature(half2* logits,
const half2* bias,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const float temperature)
{
assert(vocab_size % 2 == 0);
assert(vocab_size_padded % 2 == 0);
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int bbid = blockIdx.y;
const half2 mask_val = __float2half2_rn(-HALF_FLT_MAX);
const half2 inv_temp = __float2half2_rn(1.0f / (temperature + 1e-6f));
const int half_vocab_size = vocab_size / 2;
const int half_vocab_size_padded = vocab_size_padded / 2;
logits += bbid * half_vocab_size_padded;
for (int index = tid + bid * blockDim.x; index < half_vocab_size_padded; index += blockDim.x * gridDim.x) {
int vocab_idx = index % half_vocab_size_padded;
half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val;
if (vocab_idx < half_vocab_size) {
if (bias != nullptr) {
logit = __hadd2(logit, bias[vocab_idx]);
}
logit = __hmul2(logit, inv_temp);
}
logits[index] = logit;
}
}
template<typename T, bool IS_ADDITIVE>
__global__ void apply_repetition_penalty(T* logits,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const int step,
const int* current_ids,
const int* previous_ids,
const int* parent_ids,
const int* input_lengths,
const int max_input_length,
const float repetition_penalty)
{
assert(step > 0);
const int tid = threadIdx.x;
const int bbid = blockIdx.x;
const int batch_id = bbid / beam_width;
const int bbsize = batch_size * beam_width;
logits += bbid * vocab_size_padded;
extern __shared__ char sbuf[];
T* penalty_logits = reinterpret_cast<T*>(sbuf);
// prevent misaligment when sizeof(T) = 2
int* penalty_indices = reinterpret_cast<int*>(sbuf + (sizeof(T) * step + 31) / 32 * 32);
const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length;
if (tid == 0) {
T repet_penalty = static_cast<T>(repetition_penalty);
int prev_id = current_ids[bbid];
T prev_logit = logits[prev_id];
penalty_indices[step - 1] = prev_id;
if (IS_ADDITIVE) {
penalty_logits[step - 1] = prev_logit - repet_penalty;
}
else {
penalty_logits[step - 1] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty;
}
if (step > 1) {
int parent_beam = bbid % beam_width;
for (int i = step - 2; i >= 0; --i) {
// Skip the padded tokens.
if (i >= input_length && i < max_input_length) {
continue;
}
parent_beam = parent_ids[i * bbsize + batch_id * beam_width + parent_beam];
prev_id = previous_ids[i * bbsize + batch_id * beam_width + parent_beam];
prev_logit = logits[prev_id];
penalty_indices[i] = prev_id;
if (IS_ADDITIVE) {
penalty_logits[i] = prev_logit - repet_penalty;
}
else {
penalty_logits[i] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty;
}
}
}
}
__syncthreads();
for (int i = tid; i < step; i += blockDim.x) {
if (i >= input_length && i < max_input_length) {
continue;
}
logits[penalty_indices[i]] = penalty_logits[i];
}
}
template<typename T>
__global__ void apply_min_length_penalty(T* logits,
const int min_length,
const int* end_ids,
const int* sequence_lengths,
const int max_input_length,
const int beam_width,
const int vocab_size_padded)
{
int bbid = threadIdx.x + blockIdx.x * blockDim.x; // batch-beam index
int bid = bbid / beam_width; // batch index
// We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1,
// which is equal to the length of k/v caches.
if (sequence_lengths[bbid] + 1 - max_input_length < min_length) {
T mask_val = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
logits[bbid * vocab_size_padded + end_ids[bid]] = mask_val;
}
}
template<typename T>
void invokeAddBiasApplyPenalties(int step,
T* logits,
const int* current_ids,
const int* previous_ids,
const int* parent_ids,
const int* input_lengths,
const int* sequence_lengths,
const T* bias,
const int ite,
const int max_input_length,
const int local_batch_size,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const int* end_ids,
const float temperature,
const float repetition_penalty,
const RepetitionPenaltyType repetition_penalty_type,
const int min_length,
cudaStream_t stream)
{
if (bias != nullptr || temperature != 1.0f || vocab_size != vocab_size_padded) {
dim3 block(512);
if (std::is_same<T, half>::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0) {
dim3 grid((vocab_size_padded / 2 + block.x - 1) / block.x, beam_width * local_batch_size);
add_bias_temperature<<<grid, block, 0, stream>>>(reinterpret_cast<half2*>(logits),
reinterpret_cast<const half2*>(bias),
batch_size,
beam_width,
vocab_size,
vocab_size_padded,
temperature);
}
else {
dim3 grid((vocab_size_padded + block.x - 1) / block.x, beam_width * local_batch_size);
add_bias_temperature<<<grid, block, 0, stream>>>(
logits, bias, batch_size, beam_width, vocab_size, vocab_size_padded, temperature);
}
}
if (repetition_penalty_type != RepetitionPenaltyType::None && step > 0) {
if (repetition_penalty != getDefaultPenaltyValue(repetition_penalty_type)) {
size_t smem_size = (sizeof(T) * step + 31) / 32 * 32 + sizeof(int) * step;
dim3 block(256);
dim3 grid(beam_width * local_batch_size);
if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative) {
apply_repetition_penalty<T, false>
<<<grid, block, smem_size, stream>>>(logits,
batch_size,
beam_width,
vocab_size,
vocab_size_padded,
step,
current_ids,
previous_ids,
// TODO(jaedeokk):
// Remove (+ite ...) by getting parent_ids with offset
// and then remove 'ite' argument from the function.
parent_ids + ite * beam_width * local_batch_size,
input_lengths,
max_input_length,
repetition_penalty);
}
else if (repetition_penalty_type == RepetitionPenaltyType::Additive) {
apply_repetition_penalty<T, true>
<<<grid, block, smem_size, stream>>>(logits,
batch_size,
beam_width,
vocab_size,
vocab_size_padded,
step,
current_ids,
previous_ids,
parent_ids + ite * beam_width * local_batch_size,
input_lengths,
max_input_length,
repetition_penalty);
}
}
}
if (step - max_input_length < min_length) {
FT_CHECK_WITH_INFO(sequence_lengths != nullptr, "Need sequence_lengths to apply min length penlaty");
FT_CHECK_WITH_INFO(end_ids != nullptr, "Need end_id to apply min length penlaty");
const int block_size = min(local_batch_size * beam_width, 1024);
const int grid_size = (local_batch_size * beam_width + block_size - 1) / block_size;
apply_min_length_penalty<<<grid_size, block_size, 0, stream>>>(
logits, min_length, end_ids, sequence_lengths, max_input_length, beam_width, vocab_size_padded);
}
}
template void invokeAddBiasApplyPenalties(int step,
float* logits,
const int* current_ids,
const int* previous_ids,
const int* parent_ids,
const int* input_lengths,
const int* sequence_lengths,
const float* bias,
const int ite,
const int max_input_length,
const int local_batch_size,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const int* end_ids,
const float temperature,
const float repetition_penalty,
const RepetitionPenaltyType repetition_penalty_type,
const int min_length,
cudaStream_t stream);
template void invokeAddBiasApplyPenalties(int step,
half* logits,
const int* current_ids,
const int* previous_ids,
const int* parent_ids,
const int* input_lengths,
const int* sequence_lengths,
const half* bias,
const int ite,
const int max_input_length,
const int local_batch_size,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const int* end_ids,
const float temperature,
const float repetition_penalty,
const RepetitionPenaltyType repetition_penalty_type,
const int min_length,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_fp16.h>
#include "src/fastertransformer/kernels/penalty_types.h"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
template<typename T>
void invokeAddBiasApplyPenalties(int step,
T* logits,
const int* current_ids,
const int* previous_ids,
const int* parent_ids,
const int* input_lengths,
const int* sequence_lengths,
const T* bias,
const int ite,
const int max_input_length,
const int local_batch_size,
const int batch_size,
const int beam_width,
const int vocab_size,
const int vocab_size_padded,
const int* end_ids,
const float temperature,
const float repetition_penalty,
const RepetitionPenaltyType repetition_penalty_type,
const int min_length,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "src/fastertransformer/kernels/beam_search_topk_kernels.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/logger.h"
namespace fastertransformer {
template<typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
{
// score = log(prob) / (length)^length_penalty.
if (length_penalty == 0.0f || length == 1) {
return log_prob;
}
return log_prob / static_cast<T>(powf((float)length, length_penalty));
}
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel(const T* log_probs,
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const bool* finished,
const int* sequence_lengths,
const int vocab_size,
T diversity_rate,
float length_penalty)
{
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int thread_id = threadIdx.x;
int block_id = blockIdx.x; // batch beam index.
TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -MAX_T_VAL;
}
#pragma unroll
for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) {
int index = elem_id + block_id * vocab_size;
T score = length_penalty == 0.0f ? log_probs[index] :
apply_length_penalty(log_probs[index],
finished[block_id] ? sequence_lengths[block_id] :
sequence_lengths[block_id] + 1,
length_penalty);
partial.insert(score, index);
}
TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);
if (thread_id == 0) {
int index = block_id * MAX_K;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
topk_tmp_id_buf[index + i] = total.p[i];
topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i;
}
}
}
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{
int thread_id = threadIdx.x;
int block_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
TopK<T, MAX_K> partial;
if (thread_id == 0) {
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -MAX_T_VAL;
}
int index = block_id * MAX_K * MAX_K;
for (int i = 0; i < MAX_K * MAX_K; i++) {
partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]);
}
index = block_id * MAX_K;
for (int i = 0; i < MAX_K; i++) {
id_buf[index + i] = partial.p[i];
}
}
}
template<typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void batch_topK_kernel_v2(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf)
{
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int tid = threadIdx.x;
int bid = blockIdx.x;
TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll
for (int i = 0; i < MAX_K; ++i) {
partial.p[i] = -1;
partial.u[i] = -MAX_T_VAL;
}
int ite = MAX_K * MAX_K / THREADBLOCK_SIZE;
#pragma unroll
for (int i = 0; i < ite; i++) {
int index = bid * MAX_K * MAX_K + i * THREADBLOCK_SIZE + tid;
partial.insert((T)topk_tmp_val_buf[index], topk_tmp_id_buf[index]);
}
TopK<T, MAX_K> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, MAX_K>);
if (tid == 0) {
#pragma unroll
for (int i = 0; i < MAX_K; i++) {
id_buf[bid * MAX_K + i] = total.p[i];
}
}
}
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage_1_opt3(const T* __restrict log_probs,
T* tmp_log_probs,
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const bool* finished,
const int* sequence_lengths,
const int k,
const int vocab_size,
const float length_penalty,
const int* end_ids)
{
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
const int tmp_log_buf_index = row_id * vocab_size;
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
TopK_2<T> partial;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
if (finished != nullptr && finished[row_id] == true) {
if (tid < k) {
const int index = tmp_topk_buf_index + tid;
if (block_lane == 0 && tid == 0) {
const int end_id = end_ids[row_id / k];
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
}
else {
topk_tmp_id_buf[index] = -1;
topk_tmp_val_buf[index] = -MAX_T_VAL;
}
}
return;
}
for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
int index = elem_id + tmp_log_buf_index;
tmp_log_probs[index] = log_probs[index];
}
for (int ite = 0; ite < k; ite++) {
partial.init();
#pragma unroll
for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size;
elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) {
int index = elem_id + tmp_log_buf_index;
partial.insert(tmp_log_probs[index], index);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
if (tid == 0) {
const int index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p;
topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL;
}
__syncthreads();
}
}
template<typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf,
T* topk_tmp_val_buf,
int* ids,
BeamHypotheses beam_hyps,
const int* end_ids,
const int vocab_size,
const int k)
{
const int size = k * k * BLOCKS_PER_BEAM_;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
extern __shared__ char array[];
T* s_val = topk_tmp_val_buf + batch_id * size;
int* s_id = (int*)(array);
__shared__ int selected_beams;
__shared__ bool is_stop;
if (tid == 0) {
selected_beams = 0;
is_stop = false;
}
__syncthreads();
if (beam_hyps.num_beams != nullptr) {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) {
// initialize the buffer
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
}
else if (beam_hyps.num_beams[global_batch_idx] == k) {
return;
}
}
TopK_2<T> partial;
// In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration
// is 2*k here
for (int ite = 0; ite < 2 * k; ite++) {
partial.init();
#pragma unroll
for (int i = tid; i < size; i += BLOCK_SIZE_) {
partial.insert(s_val[i], i);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
if (tid == 0) {
if (beam_hyps.num_beams != nullptr
&& topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) {
// if beam_token does not belong to top num_beams tokens, it should not be added. Refer from
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
if (ite >= k) {
s_val[total.p] = -MAX_T_VAL;
}
else {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
const float normed_score =
apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty);
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of selected candidatet
// is higher than min_normed_score or not. If current score is better, replace worst one
// and update the min_normed_score.
if (num_beam == k) {
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) {
// end the tracing and exist this for loop
selected_beams = k;
is_stop = true;
break;
}
else {
// find the beam index which's score = min_normed_score, erase it.
for (int j = 0; j < k; j++) {
if (beam_hyps.normed_scores[global_batch_idx * k + j]
== beam_hyps.min_normed_scores[global_batch_idx]) {
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score;
for (int l = 0; l < k; l++) {
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * k + l]);
}
break;
}
}
}
}
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
* (beam_hyps.max_seq_len);
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
for (int j = beam_hyps.step - 1; j >= 0; j--) {
const int src_idx = j * beam_hyps.batch_size * k
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
prev_id = beam_hyps.parent_ids_src[src_idx];
}
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
s_val[total.p] = -MAX_T_VAL;
beam_hyps.num_beams[global_batch_idx]++;
}
}
else {
s_id[selected_beams] = total.p;
s_val[total.p] = -MAX_T_VAL;
selected_beams++;
}
}
__syncthreads();
if (selected_beams >= k) {
break;
}
}
if (tid < k && is_stop == false) {
ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]];
}
}
template<typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs,
T* tmp_log_probs,
int* topk_tmp_id_buf,
T* topk_tmp_val_buf,
const bool* finished,
const int* sequence_lengths,
const int k,
const int vocab_size,
const float length_penalty)
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
const int tmp_log_buf_index = row_id * vocab_size;
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
TopK_2<T> partial;
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) {
int index = elem_id + tmp_log_buf_index;
tmp_log_probs[index] = log_probs[index];
}
for (int ite = 0; ite < k; ite++) {
partial.init();
#pragma unroll
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size;
elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) {
int index = elem_id + tmp_log_buf_index;
partial.insert(tmp_log_probs[index], index);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
if (tid == 0) {
const int index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p;
topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL;
}
__syncthreads();
}
}
template<typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
T* topk_tmp_val_buf,
int* ids,
BeamHypotheses beam_hyps,
const int* end_ids,
const int k,
const int vocab_size)
{
const int size = k * k * BLOCKS_PER_BEAM;
const int tid = threadIdx.x;
const int batch_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
extern __shared__ char array[];
T* s_val = topk_tmp_val_buf + batch_id * size;
int* s_id = (int*)(array);
__shared__ int selected_beams;
__shared__ bool is_stop;
if (tid == 0) {
selected_beams = 0;
is_stop = false;
}
__syncthreads();
if (beam_hyps.num_beams != nullptr) {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) {
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
}
else if (beam_hyps.num_beams[global_batch_idx] == k) {
return;
}
}
TopK_2<T> partial;
// In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration
// is 2*k here
for (int ite = 0; ite < 2 * k; ite++) {
partial.init();
#pragma unroll
for (int i = tid; i < size; i += BLOCK_SIZE) {
partial.insert(s_val[i], i);
}
TopK_2<T> total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2<T>);
if (tid == 0) {
if (beam_hyps.num_beams != nullptr
&& topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) {
// if beam_token does not belong to top num_beams tokens, it should not be added. Refer from
// https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257
if (ite >= k) {
s_val[total.p] = -MAX_T_VAL;
}
else {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
const float normed_score =
apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty);
const int num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of selected candidatet
// is higher than min_normed_score or not. If current score is better, replace worst one
// and update the min_normed_score.
if (num_beam == k) {
if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) {
// end the tracing and exist this for loop
selected_beams = k;
is_stop = true;
break;
}
else {
// find the beam index which's score = min_normed_score, erase it.
for (int j = 0; j < k; j++) {
if (beam_hyps.normed_scores[global_batch_idx * k + j]
== beam_hyps.min_normed_scores[global_batch_idx]) {
beam_idx = j;
beam_hyps.num_beams[global_batch_idx]--;
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score;
for (int l = 0; l < k; l++) {
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx],
beam_hyps.normed_scores[global_batch_idx * k + l]);
}
break;
}
}
}
}
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
* (beam_hyps.max_seq_len);
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
for (int j = beam_hyps.step - 1; j >= 0; j--) {
const int src_idx = j * beam_hyps.batch_size * k
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
prev_id = beam_hyps.parent_ids_src[src_idx];
}
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx] =
min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]);
s_val[total.p] = -MAX_T_VAL;
beam_hyps.num_beams[global_batch_idx]++;
}
}
else {
s_id[selected_beams] = total.p;
s_val[total.p] = -MAX_T_VAL;
selected_beams++;
}
}
__syncthreads();
if (selected_beams >= k) {
break;
}
}
if (tid < k && is_stop == false) {
ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]];
}
}
#define CASE_K_DIV(K, BLOCK_SIZE_1, BLOCK_SIZE_2) \
case K: \
beam_topK_kernel<T, K, BLOCK_SIZE_2><<<batch_size * beam_width, BLOCK_SIZE_2, 0, stream>>>(log_probs, \
topk_tmp_id_buf, \
topk_tmp_val_buf, \
finished, \
sequence_lengths, \
vocab_size, \
diversity_rate, \
length_penalty); \
if (K < 10) \
batch_topK_kernel<T, K, BLOCK_SIZE_1> \
<<<batch_size, BLOCK_SIZE_1, 0, stream>>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \
else \
batch_topK_kernel_v2<T, K, 32><<<batch_size, 32, 0, stream>>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \
break;
#define CASE_K(K, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \
case K: \
topk_stage_1_opt3<float, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_> \
<<<batch_size * K * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(log_probs, \
temp_log_probs, \
topk_tmp_id_buf, \
topk_tmp_val_buf, \
finished, \
sequence_lengths, \
beam_width, \
vocab_size, \
length_penalty, \
end_ids); \
topk_stage_2_opt3<float, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
<<<batch_size, BLOCK_SIZE_2_, K * sizeof(int), stream>>>( \
topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, vocab_size, beam_width); \
sync_check_cuda_error(); \
break;
template<typename T>
void invokeTopkBeamSearch(void* workspace,
size_t& workspace_size,
T* log_probs,
int* ids,
BeamHypotheses* beam_hyps,
const bool* finished,
const int* sequence_lengths,
const int batch_size,
const int beam_width,
const int vocab_size_padded_,
const T diversity_rate,
const float length_penalty,
const int* end_ids,
cudaStream_t stream)
{
FT_LOG_DEBUG("%s", __PRETTY_FUNCTION__);
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token.
const int vocab_size = vocab_size_padded_;
// Beam size should be less than or equal to vocab size.
assert(beam_width <= vocab_size);
// Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
const int max_block_per_beam = 8;
int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float
int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int
int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float
// prevent memory misaligned address
temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4;
topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4;
topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4;
if (workspace == nullptr) {
workspace_size = sizeof(float) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size
+ sizeof(float) * topk_tmp_val_buf_size;
return;
}
else {
T* temp_log_probs = (T*)workspace;
int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size);
T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size);
if (diversity_rate == 0.0f) {
switch (beam_width) {
CASE_K(1, 128, 128, 8);
CASE_K(4, 128, 128, 8);
CASE_K(10, 128, 128, 8);
CASE_K(16, 128, 128, 5);
CASE_K(32, 256, 128, 1);
CASE_K(64, 256, 256, 1);
default:
topk_stage_1_opt2_general<T, 128, 1>
<<<batch_size * beam_width * 1, 128, 0, stream>>>(log_probs,
temp_log_probs,
topk_tmp_id_buf,
topk_tmp_val_buf,
finished,
sequence_lengths,
beam_width,
vocab_size,
length_penalty);
topk_stage_2_opt2_general<T, 128, 1>
<<<batch_size,
128,
beam_width * beam_width * 1 * sizeof(float) + beam_width * sizeof(int),
stream>>>(
topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, beam_width, vocab_size);
break;
}
}
else {
switch (beam_width) {
CASE_K_DIV(1, 256, 256);
CASE_K_DIV(4, 256, 256);
CASE_K_DIV(16, 256, 64);
CASE_K_DIV(32, 256, 64);
CASE_K_DIV(64, 256, 64);
default:
FT_CHECK_WITH_INFO(false, fmtstr("Topk kernel does not support beamwidth = %d \n", beam_width));
break;
}
}
return;
}
}
#undef CASE_K
#undef CASE_K_DIV
template void invokeTopkBeamSearch(void* workspace,
size_t& workspace_size,
float* log_probs,
int* ids,
BeamHypotheses* beam_hyps,
const bool* finished,
const int* sequence_lengths,
const int batch_size,
const int beam_width,
const int vocab_size_padded_,
const float diversity_rate,
const float length_penalty,
const int* end_ids,
cudaStream_t stream);
template<typename T>
__global__ void tileEncoderResults(T* tiled_output,
int* tiled_sequence_length,
const T* output,
const int* sequence_length,
const uint batch_size,
const uint beam_width,
const uint d_model)
{
if (blockIdx.x == 0) {
for (uint i = threadIdx.x; i < batch_size * beam_width; i += blockDim.x) {
tiled_sequence_length[i] = sequence_length[i / beam_width];
}
}
int tgt_offset =
blockIdx.x * gridDim.y * gridDim.z * d_model + blockIdx.y * gridDim.z * d_model + blockIdx.z * d_model;
int src_offset = blockIdx.x * gridDim.z * d_model + blockIdx.z * d_model;
for (uint i = threadIdx.x; i < d_model; i += blockDim.x) {
tiled_output[i + tgt_offset] = output[i + src_offset];
}
}
template<typename T>
void invokeTileEncoderResults(T* tiled_output,
int* tiled_sequence_length,
const T* output,
const int* sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream)
{
// tiled_output: [batch_size, beam_width, mem_max_seq_len, d_model]
// tiled_sequence_length: [batch_size, beam_width]
// output: [batch_size, mem_max_seq_len, d_model]
// sequence_length [batch_size]
dim3 grid(batch_size, beam_width, mem_max_seq_len);
bool is_half2 = (std::is_same<T, half>::value) && (d_model % 2 == 0);
if (is_half2) {
using T2 = typename TypeConverter<T>::Type; // fp16 to half2, bf16 to bf162
dim3 block(min(512, (int)(d_model / 2)));
tileEncoderResults<T2><<<grid, block, 0, stream>>>((T2*)tiled_output,
tiled_sequence_length,
(const T2*)output,
sequence_length,
batch_size,
beam_width,
d_model / 2);
}
else {
dim3 block(min(512, (int)d_model));
tileEncoderResults<T><<<grid, block, 0, stream>>>(
tiled_output, tiled_sequence_length, output, sequence_length, batch_size, beam_width, d_model);
}
}
template void invokeTileEncoderResults(float* tiled_output,
int* tiled_sequence_length,
const float* output,
const int* sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream);
template void invokeTileEncoderResults(half* tiled_output,
int* tiled_sequence_length,
const half* output,
const int* sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream);
template void invokeTileEncoderResults(half2* tiled_output,
int* tiled_sequence_length,
const half2* output,
const int* sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeTileEncoderResults(__nv_bfloat16* tiled_output,
int* tiled_sequence_length,
const __nv_bfloat16* output,
const int* sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream);
#endif
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps,
const bool* finished,
const float* cum_log_probs,
const int batch_size,
const int beam_width)
{
const int bid = blockIdx.x;
const int tgt_start_idx = beam_hyps.num_beams[bid];
if (beam_hyps.is_done[bid]) {
return;
}
for (int i = 0; i < beam_width; i++) {
if (threadIdx.x == 0) {
const int src_beam_idx = bid * beam_width + i;
const int tgt_beam_idx = bid * beam_width * 2 + i + tgt_start_idx;
const int length = beam_hyps.sequence_lengths_src[src_beam_idx];
beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] =
beam_hyps.output_ids_src[length * batch_size * beam_width + src_beam_idx];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) {
beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] =
beam_hyps.log_probs_src[length * batch_size * beam_width + src_beam_idx];
}
int prev_id = beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx];
for (int j = length - 1; j >= 0; j--) {
// output_ids_tgt need to use max_seq_len + 1 because its shape is
// [bs, beam_width, max_seq_len + 1]
beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] =
beam_hyps.output_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id];
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) {
beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] =
beam_hyps.log_probs_src[j * batch_size * beam_width + bid * beam_width + prev_id];
}
prev_id = beam_hyps.parent_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id];
}
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = length;
beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty(
cum_log_probs[src_beam_idx], finished[src_beam_idx] ? length + 1 : length, beam_hyps.length_penalty);
beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx];
beam_hyps.num_beams[bid]++;
}
}
}
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps,
const bool* finished,
const float* cum_log_probs,
const int batch_size,
const int beam_width,
cudaStream_t stream)
{
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
}
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#pragma once
namespace fastertransformer {
// In original beam search implementation, if a beam is finished, we set it as finished
// and only continue to do beam search on remain beams (namely, beam_width - 1 beams in next step)
//
// In this implementation, when a beam is finished, we trace the path and record it in output_ids_tgt,
// and also record the normalized scores. And the beam search continue to use `beam_width` beams in
// next step.
//
// After we collect `beam_width` beams, we will sort them by their norm_scores.
struct BeamHypotheses {
int* output_ids_tgt = nullptr;
int* sequence_lengths_tgt = nullptr;
float* cum_log_probs = nullptr; // cum_log
float* normed_scores = nullptr; // cum_log / (length**length_penalty)
float* log_probs = nullptr; // log probs of each generated token
float* min_normed_scores = nullptr; // record the min normed scores for each batch
int* num_beams = nullptr; // the number of finished beams we collect
bool* is_done = nullptr;
// Used to set inputs
const int* output_ids_src;
const int* parent_ids_src;
const int* sequence_lengths_src;
const int* end_ids;
const float* log_probs_src;
// some variables for kernels
int step;
int ite;
int batch_size;
int local_batch_size;
int max_seq_len;
float length_penalty;
bool early_stopping = true;
bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs
};
template<typename T>
void invokeTopkBeamSearch(void* workspace,
size_t& workspace_size,
T* log_probs,
int* ids,
BeamHypotheses* beam_hyps,
const bool* finished,
const int* sequence_lengths,
const int batch_size,
const int beam_width,
const int vocab_size_padded_,
const T diversity_rate,
const float length_penalty,
const int* end_ids,
cudaStream_t stream);
template<typename T>
void invokeTileEncoderResults(T* tiled_encoder_output,
int* tiled_encoder_sequence_length,
const T* encoder_output,
const int* encoder_sequence_length,
const size_t batch_size,
const size_t beam_width,
const size_t mem_max_seq_len,
const size_t d_model,
cudaStream_t stream);
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps,
const bool* finished,
const float* cum_log_probs,
const int batch_size,
const int beam_width,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bert_preprocess_kernels.h"
#include "src/fastertransformer/utils/cuda_bf16_fallbacks.cuh"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
namespace fastertransformer {
__global__ void getPaddingOffsetAndCuSeqLensKernel(size_t* h_valid_word_num,
int* tmp_mask_offset,
int* cu_seqlens,
const int* sequence_length,
const int batch_size,
const int max_seq_len)
{
// do cumulated sum
int total_seq_len = 0;
int cum_offset = 0;
int index = 0;
const bool calculate_cu_seqlens = cu_seqlens != nullptr;
for (int i = 0; i < batch_size; i++) {
const int seq_len = sequence_length[i];
if (calculate_cu_seqlens) {
cu_seqlens[i] = total_seq_len;
}
for (int j = 0; j < seq_len; j++) {
tmp_mask_offset[index] = cum_offset;
index++;
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
if (calculate_cu_seqlens) {
cu_seqlens[batch_size] = total_seq_len;
}
h_valid_word_num[0] = (size_t)total_seq_len;
}
void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
size_t* h_token_num,
int* tmp_mask_offset,
int* cu_seqlens,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream)
{
h_pinned_token_num[0] = 0;
getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);
while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
h_token_num[0] = h_pinned_token_num[0];
sync_check_cuda_error();
}
template<typename T>
__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len)
{
// sequence_lengths: [batch_size]
// attention_mask: [batch_size, 1, max_seq_len, max_seq_len]
attention_mask += blockIdx.x * max_seq_len * max_seq_len;
const int length = sequence_lengths[blockIdx.x];
for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) {
// int row_id = i / max_seq_len;
int col_id = i % max_seq_len;
// if (row_id < length && col_id < length) {
// TODO (bhsueh) check this modification is ok or not on other rmodel
if (col_id < length) {
attention_mask[i] = (T)(1.0f);
}
else {
attention_mask[i] = (T)(0.0f);
}
}
}
template<typename T>
void invokeBuildEncoderAttentionMask(
T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream)
{
buildEncoderAttentionMaskKernel<<<batch_size, 256, 0, stream>>>(attention_mask, sequence_lengths, max_seq_len);
}
template void invokeBuildEncoderAttentionMask(float* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
template void invokeBuildEncoderAttentionMask(half* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
#ifdef ENABLE_FP8
template void invokeBuildEncoderAttentionMask(__nv_fp8_e4m3* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
template void invokeBuildEncoderAttentionMask(__nv_bfloat16* attention_mask,
const int* sequence_lengths,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
#endif
__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size)
{
// use for get tensorrt fused mha padding offset
// when we remove the padding
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
for (int i = 0; i < batch_size; i++) {
tmp_offset[i + 1] = tmp_offset[i] + sequence_length[i];
}
}
__syncthreads();
for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset,
const int* sequence_length,
const int batch_size,
cudaStream_t stream)
{
getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (batch_size + 1), stream>>>(
trt_mha_padding_offset, sequence_length, batch_size);
}
__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset,
const int* sequence_length,
const int request_batch_size,
const int request_seq_len)
{
// use for get tensorrt fused mha padding offset
// when we keep the padding
extern __shared__ int tmp_offset[];
if (threadIdx.x == 0) {
tmp_offset[0] = 0;
for (int i = 0; i < request_batch_size; i++) {
tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_length[i];
tmp_offset[i * 2 + 2] = request_seq_len * (i + 1);
}
}
__syncthreads();
for (int i = threadIdx.x; i < 2 * request_batch_size + 1; i += blockDim.x) {
trt_mha_padding_offset[i] = tmp_offset[i];
}
}
void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset,
const int* sequence_length,
const int request_batch_size,
const int request_seq_len,
cudaStream_t stream)
{
getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (2 * request_batch_size + 1), stream>>>(
trt_mha_padding_offset, sequence_length, request_batch_size, request_seq_len);
}
template<typename T>
__global__ void rebuild_sequence_length_padding(const T* src, T* dst, const int* padding_offset, const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int dst_seq_id = bid + padding_offset[bid];
const int src_seq_id = bid;
for (int i = tid; i < n; i += blockDim.x) {
dst[dst_seq_id * n + i] = src[src_seq_id * n + i];
}
}
template<typename T>
void invokeRebuildPadding(
T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream)
{
// src: [token_num, hidden_dim]
// dst: [batch_size*max_seq_len, hidden_dim]
rebuild_sequence_length_padding<<<token_num, 256, 0, stream>>>(src, dst, padding_offset, hidden_dim);
}
template<typename T>
void invokeRebuildPadding(
T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream);
template void invokeRebuildPadding(float* dst,
const float* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
template void invokeRebuildPadding(half* dst,
const half* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeRebuildPadding(__nv_bfloat16* dst,
const __nv_bfloat16* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
template void invokeRebuildPadding(__nv_fp8_e4m3* dst,
const __nv_fp8_e4m3* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#endif // ENABLE_FP8
template<typename T>
__global__ void remove_padding(T* tgt, const T* src, const int* padding_offset, const int n)
{
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int src_seq_id = bid + padding_offset[bid];
const int tgt_seq_id = bid;
for (int i = tid; i < n; i += blockDim.x) {
tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i];
}
}
template<typename T>
void invokeRemovePadding(
T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream)
{
remove_padding<<<token_num, 256, 0, stream>>>(dst, src, padding_offset, hidden_dim);
}
template void invokeRemovePadding(float* dst,
const float* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
template void invokeRemovePadding(half* dst,
const half* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#ifdef ENABLE_FP8
template void invokeRemovePadding(__nv_fp8_e4m3* dst,
const __nv_fp8_e4m3* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
template void invokeRemovePadding(__nv_bfloat16* dst,
const __nv_bfloat16* src,
const int* padding_offset,
const int token_num,
const int hidden_dim,
cudaStream_t stream);
#endif
template<typename T>
__global__ void buildRelativeAttentionBias(T* relative_attention_bias,
const T* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance)
{
const int head_id = blockIdx.x;
for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) {
int row_id = seq_id / seq_len;
int col_id = seq_id % seq_len;
int relative_position = col_id - row_id;
int relative_buckets = 0;
int tmp_num_bucket = num_bucket;
if (is_bidirectional) {
tmp_num_bucket /= 2;
if (relative_position > 0) {
relative_buckets += tmp_num_bucket;
}
else {
relative_position *= -1;
}
}
else {
relative_position = abs(relative_position);
}
int max_exact = tmp_num_bucket / 2;
bool is_small = relative_position < max_exact;
int relative_position_if_large =
max_exact
+ (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact)
* (tmp_num_bucket - max_exact));
relative_position_if_large = min(relative_position_if_large, tmp_num_bucket - 1);
relative_buckets += is_small ? relative_position : relative_position_if_large;
relative_attention_bias[head_id * seq_len * seq_len + seq_id] =
relative_attention_bias_table[head_id * num_bucket + relative_buckets];
}
}
template<typename T>
void invokeBuildRelativeAttentionBias(T* relative_attention_bias,
const T* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance,
const PositionEmbeddingType position_embedding_type,
cudaStream_t stream)
{
if (position_embedding_type == PositionEmbeddingType::absolute) {
return;
}
dim3 grid(head_num);
dim3 block(256);
buildRelativeAttentionBias<<<grid, block, 0, stream>>>(relative_attention_bias,
relative_attention_bias_table,
head_num,
seq_len,
num_bucket,
is_bidirectional,
max_distance);
}
template void invokeBuildRelativeAttentionBias(float* relative_attention_bias,
const float* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance,
const PositionEmbeddingType position_embedding_type,
cudaStream_t stream);
template void invokeBuildRelativeAttentionBias(half* relative_attention_bias,
const half* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance,
const PositionEmbeddingType position_embedding_type,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBuildRelativeAttentionBias(__nv_bfloat16* relative_attention_bias,
const __nv_bfloat16* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance,
const PositionEmbeddingType position_embedding_type,
cudaStream_t stream);
#endif
#ifdef ENABLE_FP8
template<typename T_OUT, typename T_IN>
__global__ void getLastTokenDequantize(getLastTokenDequantizeParam<T_OUT, T_IN> param)
{
param.output[blockIdx.x * param.d_model + threadIdx.x] =
(T_OUT)((float)param.input[blockIdx.x * param.max_seq_len * param.d_model + threadIdx.x]
* __ldg(param.input_scale));
}
template<typename T_OUT, typename T_IN>
void invokeGetLastTokenDequantize(getLastTokenDequantizeParam<T_OUT, T_IN> param)
{
FT_CHECK(param.d_model <= 1024);
getLastTokenDequantize<T_OUT, T_IN><<<param.batch_size, param.d_model, 0, param.stream>>>(param);
}
template void invokeGetLastTokenDequantize<__nv_bfloat16, __nv_fp8_e4m3>(
getLastTokenDequantizeParam<__nv_bfloat16, __nv_fp8_e4m3> param);
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
__global__ void quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam<T_OUT, T_IN, quantize_mode> param)
{
for (int i = threadIdx.x; i < param.d_model; i += blockDim.x) {
int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 0 : param.padding_offset[blockIdx.x]);
if (quantize_mode == QUANTIZE_MODE::PER_TENSOR) {
param.dst[padded_row_id * param.d_model + i] =
(T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale));
}
else if (quantize_mode == QUANTIZE_MODE::PER_CHANNEL) {
param.dst[padded_row_id * param.d_model + i] =
(T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale + i));
}
}
}
template<>
__global__ void
quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam<half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR> param)
{
int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 0 : __ldg(&param.padding_offset[blockIdx.x]));
__nv_fp8x4_e4m3* src_ptr = ((__nv_fp8x4_e4m3*)param.src) + blockIdx.x * (param.d_model / 4);
half2* dst_ptr = ((half2*)param.dst) + padded_row_id * (param.d_model / 2);
half2 scale = cuda_cast<half2>(__ldg(param.scale));
for (int i = threadIdx.x; i < param.d_model / 4; i += blockDim.x) {
half2 val_0;
half2 val_1;
fp8x4_e4m3_to_half2(&val_0, &val_1, src_ptr + i);
val_0 = hmul2(val_0, scale);
val_1 = hmul2(val_1, scale);
dst_ptr[2 * i + 0] = val_0;
dst_ptr[2 * i + 1] = val_1;
}
}
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam<T_OUT, T_IN, quantize_mode> param)
{
dim3 grid(param.token_num);
dim3 block(param.d_model);
FT_CHECK(block.x <= 1024);
if (block.x % 4 == 0) {
block.x /= 4;
}
quantizeMatrixRebuildPadding<<<grid, block, 0, param.stream>>>(param);
}
template void invokeQuantizeMatrixRebuildPadding<half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR>(
QuantizeMatrixRebuildPaddingParam<half, __nv_fp8_e4m3, QUANTIZE_MODE::PER_TENSOR> param);
#endif
} // namespace fastertransformer
\ No newline at end of file
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/kernels/gen_relative_pos_bias.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#ifdef ENABLE_FP8
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#endif // ENABLE_FP8
namespace fastertransformer {
void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
size_t* h_token_num,
int* tmp_mask_offset,
int* cu_seqlens,
const int* sequence_length,
const int batch_size,
const int max_seq_len,
cudaStream_t stream);
inline void invokeGetPaddingOffset(size_t* h_pinned_token_num,
size_t* h_token_num,
int* tmp_mask_offset,
const int* sequence_length,
const int batch_size,
const int max_seq_len,
cudaStream_t stream)
{
invokeGetPaddingOffsetAndCuSeqLens(
h_pinned_token_num, h_token_num, tmp_mask_offset, nullptr, sequence_length, batch_size, max_seq_len, stream);
}
template<typename T>
void invokeBuildEncoderAttentionMask(
T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream);
void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset,
const int* sequence_length,
const int request_batch_size,
cudaStream_t stream);
void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset,
const int* sequence_length,
const int request_batch_size,
const int request_seq_len,
cudaStream_t stream);
template<typename T>
void invokeRebuildPadding(
T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream);
template<typename T>
void invokeRemovePadding(
T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream);
template<typename T>
void invokeBuildRelativeAttentionBias(T* relative_attention_bias,
const T* relative_attention_bias_table,
const int head_num,
const int seq_len,
const int num_bucket,
const bool is_bidirectional,
const int max_distance,
const PositionEmbeddingType position_embedding_type,
cudaStream_t stream);
template<typename T_OUT, typename T_IN>
struct getLastTokenDequantizeParam {
T_OUT* const output;
T_IN const* const input;
float const* const input_scale;
const int batch_size;
const int max_seq_len;
const int d_model;
cudaStream_t stream;
};
template<typename T_OUT, typename T_IN>
void invokeGetLastTokenDequantize(getLastTokenDequantizeParam<T_OUT, T_IN> param);
#ifdef ENABLE_FP8
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
struct QuantizeMatrixRebuildPaddingParam {
T_OUT* dst;
const T_IN* src;
const int* padding_offset;
const int token_num;
const int d_model;
const float* scale;
cudaStream_t stream;
};
template<typename T_OUT, typename T_IN, QUANTIZE_MODE quantize_mode>
void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam<T_OUT, T_IN, quantize_mode> param);
#endif // ENABLE_FP8
} // namespace fastertransformer
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "custom_ar_kernels.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
namespace fastertransformer {
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b)
{
uint32_t c;
asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
__threadfence_system();
asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr)
{
#if __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type Converter that packs data format to 128 bits data type
template<typename T>
struct ARTypeConverter {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct ARTypeConverter<__nv_bfloat16> {
using Type = bf168;
};
#endif
// add two 128b data
template<typename T_IN, typename T_COMP>
inline __device__ T_IN add128b(T_IN a, T_IN b);
template<>
inline __device__ uint4 add128b<uint4, uint16_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
template<>
inline __device__ uint4 add128b<uint4, uint32_t>(uint4 a, uint4 b)
{
uint4 c;
c.x = fadd(a.x, b.x);
c.y = fadd(a.y, b.y);
c.z = fadd(a.z, b.z);
c.w = fadd(a.w, b.w);
return c;
}
#ifdef ENABLE_BF16
template<>
inline __device__ bf168 add128b<bf168, __nv_bfloat16>(bf168 a, bf168 b)
{
bf168 c;
c.x = bf16hadd2(a.x, b.x);
c.y = bf16hadd2(a.y, b.y);
c.z = bf16hadd2(a.z, b.z);
c.w = bf16hadd2(a.w, b.w);
return c;
}
#endif
// init 128bits data with 0
template<typename T>
inline __device__ T init_packed_type();
template<>
inline __device__ uint4 init_packed_type()
{
return make_uint4(0u, 0u, 0u, 0u);
}
#ifdef ENABLE_BF16
template<>
inline __device__ bf168 init_packed_type()
{
bf168 val;
uint4& val_u = reinterpret_cast<uint4&>(val);
val_u = make_uint4(0u, 0u, 0u, 0u);
return val;
}
#endif
template<typename T>
static __global__ void oneShotAllReduceKernel(AllReduceParams<T> params)
{
// The block index.
const int bidx = blockIdx.x;
// The thread index with the block.
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
// The end of the segment computed by that block.
size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
// Synchronize the ranks.
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
if (tidx < RANKS_PER_NODE) {
// The 1st block notifies the other ranks.
if (bidx == 0) {
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
}
// Busy-wait until all ranks are ready.
while (barrier_d[tidx] < params.barrier_flag) {}
}
// Make sure we can move on...
__syncthreads();
// The source pointers. Distributed round-robin for the different warps.
const T* src_d[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = params.peer_comm_buffer_ptrs[rank];
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][iter_offset])[0];
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
sums = add128b<PackedType, T>(sums, vals[ii]);
}
// Store to the destination buffer.
reinterpret_cast<PackedType*>(&params.local_output_buffer_ptr[iter_offset])[0] = sums;
}
}
template<typename T>
static __global__ void twoShotAllReduceKernel(AllReduceParams<T> params)
{
// The block index.
const int bidx = blockIdx.x;
// The thread index with the block.
const int tidx = threadIdx.x;
// The number of elements packed into one for comms
static constexpr int NUM_ELTS = std::is_same<T, uint32_t>::value ? 4 : 8;
// Packed data type for comms
using PackedType = typename ARTypeConverter<T>::Type;
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS + params.rank_offset;
// The end of the segment computed by that block.
size_t max_offset = min(offset + params.elts_per_block, params.elts_total);
// Synchronize the ranks.
volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank];
if (tidx < RANKS_PER_NODE) {
// The 1st block notifies the other ranks.
if (bidx == 0) {
params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag;
}
// Busy-wait until all ranks are ready.
while (barrier_d[tidx] < params.barrier_flag) {}
}
// Make sure we can move on...
__syncthreads();
// The source pointers. Distributed round-robin for the different warps.
T* src_d[RANKS_PER_NODE];
// The destination ranks for round-robin gathering
size_t dst_rank[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
src_d[ii] = params.peer_comm_buffer_ptrs[rank];
dst_rank[ii] = rank;
}
// Each block accumulates the values from the different GPUs on the same node.
for (size_t local_offset = offset; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS) {
// Iterate over the different ranks/devices on the node to load the values.
PackedType vals[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
vals[ii] = reinterpret_cast<const PackedType*>(&src_d[ii][local_offset])[0];
}
// Sum the values from the different ranks.
PackedType sums = init_packed_type<PackedType>();
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
sums = add128b<PackedType, T>(sums, vals[ii]);
}
// Store to the local buffer.
reinterpret_cast<PackedType*>(&src_d[0][local_offset])[0] = sums;
}
// sync threads to make sure all block threads have the sums
__syncthreads();
// barreris among the blocks with the same idx (release-acuqire semantics)
if (tidx < RANKS_PER_NODE) {
// The all blocks notifies the other ranks.
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
st_flag_release(params.barrier_flag, params.peer_barrier_ptrs[tidx] + flag_block_offset + params.local_rank);
// Busy-wait until all ranks are ready.
uint32_t rank_barrier = 0;
uint32_t* peer_barrier_d = params.peer_barrier_ptrs[params.local_rank] + flag_block_offset + tidx;
do {
ld_flag_acquire(rank_barrier, peer_barrier_d);
} while (rank_barrier != params.barrier_flag);
}
// sync threads to make sure all other ranks has the final partial results
__syncthreads();
// Gather all needed elts from other intra-node ranks
for (size_t local_offset = offset; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS) {
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
// use round-robin gathering from other ranks
int offset_rank = local_offset + (dst_rank[ii] - params.local_rank) * params.elts_per_rank;
reinterpret_cast<PackedType*>(&params.local_output_buffer_ptr[offset_rank])[0] =
reinterpret_cast<PackedType*>(&src_d[dst_rank[ii]][offset_rank])[0];
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void kernelLaunchConfig(
int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo, size_t data_type_bytes)
{
assert(data_type_bytes == 2 || data_type_bytes == 4);
// NOTE: need to support FP16 and FP32
size_t elts_per_thread = 16 / data_type_bytes;
size_t elts_per_warp = (16 * WARP_SIZE) / data_type_bytes;
switch (kernel_algo) {
case 0: { // one stage all reduce algo
assert(elts % elts_per_warp == 0);
if (elts < (elts_per_thread * DEFAULT_BLOCK_SIZE)) { // local reduce
threads_per_block = ((elts + elts_per_warp - 1) / elts_per_warp) * WARP_SIZE;
blocks_per_grid = 1;
}
else { // local reduce
if (elts % (elts_per_thread * threads_per_block) == 0) {
blocks_per_grid =
(elts + elts_per_thread * threads_per_block - 1) / (elts_per_thread * threads_per_block);
// NOTE: need to adjust here
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
int iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
}
else {
int total_threads = elts / elts_per_thread;
blocks_per_grid = 1;
while (total_threads % blocks_per_grid != 0
|| total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
}
}
break;
}
case 1: { // two stage all reduce algo
int total_threads = elts / RANKS_PER_NODE / RANKS_PER_NODE;
assert(elts / RANKS_PER_NODE % RANKS_PER_NODE == 0 && total_threads % WARP_SIZE == 0);
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) {
blocks_per_grid += 1;
}
threads_per_block = total_threads / blocks_per_grid;
// NOTE: need to adjust here
if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) {
int iter_factor = 1;
while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) {
iter_factor += 1;
}
blocks_per_grid /= iter_factor;
}
break;
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t stream)
{
size_t elts_total = param.elts_total;
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
int kernel_algo = 1;
if (elts_total * sizeof(T) <= DEFALUT_ALGO_AR_SIZE_THRESHOLD) {
kernel_algo = 0;
}
kernelLaunchConfig(blocks_per_grid, threads_per_block, elts_total, kernel_algo, sizeof(T));
if (kernel_algo == 0) {
param.elts_per_rank = elts_total;
param.elts_per_block = param.elts_per_rank / blocks_per_grid;
oneShotAllReduceKernel<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
else {
param.elts_per_rank = param.elts_total / RANKS_PER_NODE;
param.elts_per_block = param.elts_per_rank / blocks_per_grid;
param.rank_offset = param.rank * param.elts_per_rank;
twoShotAllReduceKernel<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}
}
// Template instantiation
template void invokeOneOrTwoShotAllReduceKernel<uint16_t>(AllReduceParams<uint16_t>& param, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<__nv_bfloat16>& param,
cudaStream_t stream);
#endif
template void invokeOneOrTwoShotAllReduceKernel<uint32_t>(AllReduceParams<uint32_t>& param, cudaStream_t stream);
} // namespace fastertransformer
\ No newline at end of file
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <iostream>
#include "src/fastertransformer/utils/cuda_utils.h"
#define CUSTOM_AR_SIZE_THRESHOLD 50331648
#define MAX_ALL_REDUCE_BLOCKS 24
#define FLAG(a) ((uint32_t)((a) % 0x146))
#define RANKS_PER_NODE 8
#define WARP_SIZE 32
#define DEFAULT_BLOCK_SIZE 1024
#define DEFALUT_ALGO_AR_SIZE_THRESHOLD 393216
namespace fastertransformer {
#ifdef ENABLE_BF16
typedef struct bf168 {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
} bf168;
#endif
template<typename T>
struct AllReduceParams {
size_t elts_total;
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t rank, local_rank, node_id;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs[RANKS_PER_NODE];
T* peer_comm_buffer_ptrs[RANKS_PER_NODE];
T* local_output_buffer_ptr;
};
template<typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t stream);
void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo);
} // namespace fastertransformer
\ No newline at end of file
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head) {
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template<typename T>
struct Multihead_attention_params_base {
// The output buffer. Dimensions B x D.
T* out = nullptr;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
const T *q = nullptr, *q_bias = nullptr;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
const T *k = nullptr, *k_bias = nullptr;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const T *v = nullptr, *v_bias = nullptr;
// The cache for the Ks. The size must be at least B x L x D.
T* k_cache = nullptr;
// The cache for the Vs. The size must be at least B x L x D.
T* v_cache = nullptr;
// The indirections to use for cache when beam sampling.
const int* cache_indir = nullptr;
// scales
const float* query_weight_output_scale = nullptr;
const float* attention_qk_scale = nullptr;
const float* attention_output_weight_input_scale_inv = nullptr;
// Stride to handle the case when KQV is a single buffer
int stride = 0;
// The batch size.
int batch_size = 0;
// The beam width
int beam_width = 0;
// The sequence length.
int memory_max_len = 0;
// The number of heads (H).
int num_heads = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// The per-head latent space reserved for rotary embeddings.
int rotary_embedding_dim = 0;
// The maximum length of input sentences.
int max_input_length = 0;
// The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
int timestep = 0;
// The current timestep of each sentences (support different timestep for different sentences)
// The 1.f / sqrt(Dh). Computed on the host.
float inv_sqrt_dh = 0.0f;
// Used when we have some input context like gpt
const int* total_padding_tokens = nullptr;
const bool* masked_tokens = nullptr;
const int* prefix_prompt_lengths = nullptr;
int max_prefix_prompt_length = 0;
const T* relative_attention_bias = nullptr;
int relative_attention_bias_stride = 0;
// The slope per head of linear position bias to attention score (H).
const T* linear_bias_slopes = nullptr;
const T* ia3_key_weights = nullptr;
const T* ia3_value_weights = nullptr;
const int* ia3_tasks = nullptr;
const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr;
int int8_mode = 0;
};
template<typename T>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
// allows to exist attention eary
bool* finished = nullptr;
// required in case of masked attention with different length
const int* length_per_sample = nullptr;
T** k_cache_per_sample = nullptr;
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
bool k_cache_interleaved = true;
};
template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T>;
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.cuh"
////////////////////////////////////////////////////////////////////////////////////////////////////
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS> \
<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
// constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = params.timestep;
FT_CHECK(params.cache_indir == nullptr);
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template void mmha_launch_kernel<float, 128, 128, Masked_multihead_attention_params<float>>(
const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 128, 128, Masked_multihead_attention_params<uint16_t>>(
const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif
#undef MMHA_LAUNCH_KERNEL
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include <assert.h>
#include <float.h>
#include <type_traits>
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
// #define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
// 64, 128 and 256 threads per block.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
// values for x are chosen to create chunks of 16 bytes.
//
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is much simpler
// as it is [B, H, L, Dh].
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh>
struct Qk_vec_m_ {};
template<>
struct Qk_vec_m_<float, 32> {
using Type = float;
};
template<>
struct Qk_vec_m_<float, 64> {
using Type = float2;
};
template<>
struct Qk_vec_m_<float, 128> {
using Type = float4;
};
template<>
struct Qk_vec_m_<float, 256> {
using Type = float4;
};
template<>
struct Qk_vec_m_<uint16_t, 32> {
using Type = uint32_t;
};
template<>
struct Qk_vec_m_<uint16_t, 64> {
using Type = uint32_t;
};
template<>
struct Qk_vec_m_<uint16_t, 128> {
using Type = uint2;
};
template<>
struct Qk_vec_m_<uint16_t, 256> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct Qk_vec_m_<__nv_bfloat16, 32> {
using Type = __nv_bfloat162;
};
template<>
struct Qk_vec_m_<__nv_bfloat16, 64> {
using Type = __nv_bfloat162;
};
template<>
struct Qk_vec_m_<__nv_bfloat16, 128> {
using Type = bf16_4_t;
};
template<>
struct Qk_vec_m_<__nv_bfloat16, 256> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
template<>
struct Qk_vec_m_<__nv_fp8_e4m3, 32> {
using Type = fp8_4_t;
};
template<>
struct Qk_vec_m_<__nv_fp8_e4m3, 64> {
using Type = fp8_4_t;
};
template<>
struct Qk_vec_m_<__nv_fp8_e4m3, 128> {
using Type = fp8_4_t;
};
template<>
struct Qk_vec_m_<__nv_fp8_e4m3, 256> {
using Type = fp8_4_t;
};
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh>
struct Qk_vec_k_ {
using Type = typename Qk_vec_m_<T, Dh>::Type;
};
#ifdef ENABLE_FP8
template<>
struct Qk_vec_k_<__nv_fp8_e4m3, 32> {
using Type = float4;
};
template<>
struct Qk_vec_k_<__nv_fp8_e4m3, 64> {
using Type = float4;
};
template<>
struct Qk_vec_k_<__nv_fp8_e4m3, 128> {
using Type = float4;
};
template<>
struct Qk_vec_k_<__nv_fp8_e4m3, 256> {
using Type = float4;
};
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct K_vec_m_ {};
template<>
struct K_vec_m_<float, 4> {
using Type = float;
};
template<>
struct K_vec_m_<float, 2> {
using Type = float2;
};
template<>
struct K_vec_m_<float, 1> {
using Type = float4;
};
template<>
struct K_vec_m_<uint16_t, 4> {
using Type = uint32_t;
};
template<>
struct K_vec_m_<uint16_t, 2> {
using Type = uint2;
};
template<>
struct K_vec_m_<uint16_t, 1> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct K_vec_m_<__nv_bfloat16, 4> {
using Type = __nv_bfloat162;
};
template<>
struct K_vec_m_<__nv_bfloat16, 2> {
using Type = bf16_4_t;
};
template<>
struct K_vec_m_<__nv_bfloat16, 1> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes
#ifdef ENABLE_FP8
template<>
struct K_vec_m_<__nv_fp8_e4m3, 4> {
using Type = fp8_4_t;
};
template<>
struct K_vec_m_<__nv_fp8_e4m3, 2> {
using Type = fp8_4_t;
}; // Defined for compilation-purpose only, do not use
template<>
struct K_vec_m_<__nv_fp8_e4m3, 1> {
using Type = fp8_4_t;
}; // Defined for compilation-purpose only, do not use
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct K_vec_k_ {
using Type = typename K_vec_m_<T, THREADS_PER_KEY>::Type;
};
#ifdef ENABLE_FP8
template<>
struct K_vec_k_<__nv_fp8_e4m3, 4> {
using Type = float4;
};
template<>
struct K_vec_k_<__nv_fp8_e4m3, 2> {
using Type = float4;
}; // Defined for compilation-purpose only, do not use
template<>
struct K_vec_k_<__nv_fp8_e4m3, 1> {
using Type = float4;
}; // Defined for compilation-purpose only, do not use
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE>
struct V_vec_m_ {};
template<>
struct V_vec_m_<float, 1> {
using Type = float;
};
template<>
struct V_vec_m_<float, 2> {
using Type = float2;
};
template<>
struct V_vec_m_<float, 4> {
using Type = float4;
};
template<>
struct V_vec_m_<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct V_vec_m_<uint16_t, 4> {
using Type = uint2;
};
template<>
struct V_vec_m_<uint16_t, 8> {
using Type = uint4;
};
#ifdef ENABLE_BF16
template<>
struct V_vec_m_<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template<>
struct V_vec_m_<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template<>
struct V_vec_m_<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
template<>
struct V_vec_m_<__nv_fp8_e4m3, 4> {
using Type = fp8_4_t;
};
template<>
struct V_vec_m_<__nv_fp8_e4m3, 8> {
using Type = fp8_4_t;
};
template<>
struct V_vec_m_<__nv_fp8_e4m3, 16> {
using Type = fp8_4_t;
};
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int V_VEC_SIZE>
struct V_vec_k_ {
using Type = typename V_vec_m_<T, V_VEC_SIZE>::Type;
};
#ifdef ENABLE_FP8
template<>
struct V_vec_k_<__nv_fp8_e4m3, 4> {
using Type = float4;
};
template<>
struct V_vec_k_<__nv_fp8_e4m3, 8> {
using Type = float4;
};
template<>
struct V_vec_k_<__nv_fp8_e4m3, 16> {
using Type = float4;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T>
struct Qk_vec_acum_fp32_ {};
template<>
struct Qk_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<float4> {
using Type = float4;
};
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template<>
struct Qk_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct Qk_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct Qk_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct Qk_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#ifdef ENABLE_FP8
// template<>
// struct Qk_vec_acum_fp32_<fp8_2_t> {
// using Type = float2;
// };
template<>
struct Qk_vec_acum_fp32_<fp8_4_t> {
using Type = Float4_;
};
// template<>
// struct Qk_vec_acum_fp32_<fp8_8_t> {
// using Type = Float4_;
// };
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct K_vec_acum_fp32_ {};
template<>
struct K_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct K_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<float4> {
using Type = float4;
};
template<>
struct K_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct K_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
template<>
struct K_vec_acum_fp32_<__nv_bfloat16> {
using Type = float;
};
template<>
struct K_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct K_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct K_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#ifdef ENABLE_FP8
// template<>
// struct K_vec_acum_fp32_<fp8_2_t> {
// using Type = float2;
// };
template<>
struct K_vec_acum_fp32_<fp8_4_t> {
using Type = Float4_;
};
// template<>
// struct K_vec_acum_fp32_<fp8_8_t> {
// using Type = Float4_;
// };
#endif // ENABLE_FP8
#endif // MMHA_USE_FP32_ACUM_FOR_FMA
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T>
struct V_vec_acum_fp32_ {};
template<>
struct V_vec_acum_fp32_<float> {
using Type = float;
};
template<>
struct V_vec_acum_fp32_<float2> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<float4> {
using Type = float4;
};
template<>
struct V_vec_acum_fp32_<uint32_t> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<uint2> {
using Type = Float4_;
};
template<>
struct V_vec_acum_fp32_<uint4> {
using Type = Float8_;
};
#ifdef ENABLE_BF16
template<>
struct V_vec_acum_fp32_<__nv_bfloat162> {
using Type = float2;
};
template<>
struct V_vec_acum_fp32_<bf16_4_t> {
using Type = Float4_;
};
template<>
struct V_vec_acum_fp32_<bf16_8_t> {
using Type = Float8_;
};
#endif // ENABLE_BF16
#ifdef ENABLE_FP8
// template<>
// struct V_vec_acum_fp32_<fp8_2_t> {
// using Type = float2;
// };
template<>
struct V_vec_acum_fp32_<fp8_4_t> {
using Type = Float4_;
};
// template<>
// struct V_vec_acum_fp32_<fp8_8_t> {
// using Type = Float4_;
// };
#endif // ENABLE_FP8
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
#ifdef ENABLE_FP8
// fp8_t
template<>
__inline__ __device__ float vec_conversion<float, __nv_fp8_e4m3>(const __nv_fp8_e4m3& a)
{
return float(a);
}
template<>
__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a)
{
return __nv_fp8_e4m3(a);
}
// fp8_2_t
template<>
__inline__ __device__ float2 vec_conversion<float2, fp8_2_t>(const fp8_2_t& a)
{
return float2(a);
}
template<>
__inline__ __device__ fp8_2_t vec_conversion<fp8_2_t, float2>(const float2& a)
{
return fp8_2_t(a);
}
// fp8_4_t
template<>
__inline__ __device__ float4 vec_conversion<float4, fp8_4_t>(const fp8_4_t& a)
{
return float4(a);
}
template<>
__inline__ __device__ fp8_4_t vec_conversion<fp8_4_t, float4>(const float4& a)
{
return fp8_4_t(a);
}
#endif // ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
#else
using K_vec_acum = K_vec;
#endif
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
template<typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
{
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
float4 c;
float zero = 0.f;
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
using K_vec_acum = uint32_t;
#endif
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Qk_dot<uint16_t, 4> {
template<int N>
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return qk_hmma_dot_(q, k);
#else
return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float& dst, float src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint16_t& dst, float src)
{
dst = float_to_half(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint32_t& dst, float2 src)
{
dst = float2_to_half2(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
{
dst = __float2bfloat16(src);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst = __float22bfloat162_rn(src);
#else
dst = __floats2bfloat162_rn(src.x, src.y);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint2& dst, Float4_ src)
{
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint2& dst, float4 src)
{
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(uint4& dst, Float8_ src)
{
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
{
convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#else
dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ void convert_from_float(fp8_4_t& dst, float4 src)
{
dst = fp8_4_t(src);
}
inline __device__ void convert_from_float(fp8_2_t& dst, float2 src)
{
dst = fp8_2_t(src);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float2& dst, float2 src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void convert_from_float(float4& dst, float4 src)
{
dst = src;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float convert_to_float(float4 u)
{
return u.x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float convert_to_float(uint4 u)
{
float2 tmp = half2_to_float2(u.x);
return tmp.x;
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float cast_to_float(float u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 cast_to_float(float2 u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 cast_to_float(float4 u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ cast_to_float(Float4_ u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ cast_to_float(Float8_ u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 cast_to_float(uint32_t u)
{
return half2_to_float2(u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ cast_to_float(uint2 u)
{
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ cast_to_float(uint4 u)
{
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float float_from_int8(int8_t u)
{
return u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 float_from_int8(int16_t u)
{
union {
int16_t int16;
int8_t int8[2];
};
int16 = u;
return make_float2(int8[0], int8[1]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 float_from_int8(int32_t u)
{
union {
int32_t int32;
int8_t int8[4];
};
int32 = u;
return make_float4(int8[0], int8[1], int8[2], int8[3]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
inline __device__ Float8_ float_from_int8(int64_t u)
{
union {
int64_t int64;
int16_t int16[4];
};
int64 = u;
return Float8_ {float_from_int8(int16[0]),
float_from_int8(int16[1]),
float_from_int8(int16[2]),
float_from_int8(int16[3])};
}
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int8_t cast_to_int8(float val)
{
union {
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int32_t cast_to_int8(float4 val)
{
union {
int8_t int8[4];
int32_t int32;
};
int8[0] = cast_to_int8(val.x);
int8[1] = cast_to_int8(val.y);
int8[2] = cast_to_int8(val.z);
int8[3] = cast_to_int8(val.w);
return int32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int64_t cast_to_int8(Float8_ val)
{
union {
int8_t int8[8];
int64_t int64;
};
int8[0] = cast_to_int8(val.x.x);
int8[1] = cast_to_int8(val.x.y);
int8[2] = cast_to_int8(val.y.x);
int8[3] = cast_to_int8(val.y.y);
int8[4] = cast_to_int8(val.z.x);
int8[5] = cast_to_int8(val.z.y);
int8[6] = cast_to_int8(val.w.x);
int8[7] = cast_to_int8(val.w.y);
return int64;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ __host__ T div_up(T m, T n)
{
return (m + n - 1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct kernel_type_t {
using Type = T;
};
#ifdef ENABLE_FP8
template<>
struct kernel_type_t<__nv_fp8_e4m3> {
using Type = float;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline size_t
smem_size_in_bytes(const Multihead_attention_params<T>& params, int threads_per_value, int threads_per_block)
{
using Tk = typename kernel_type_t<T>::Type;
// The amount of shared memory needed to store the Q*K^T values in float.
const int max_timesteps = min(params.timestep, params.memory_max_len);
size_t qk_sz = div_up(max_timesteps + 1, 4) * 16;
// The extra memory needed if we are not using floats for the final logits.
size_t logits_sz = 0;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(Tk) != 4) {
// TDOD
logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk);
}
#endif
// The total size needed during softmax.
size_t softmax_sz = qk_sz + logits_sz;
// The number of partial rows to reduce in the final reduction.
int rows_per_red = threads_per_block / threads_per_value;
// The amount of storage needed to finalize the outputs.
size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2;
// The max.
return max(softmax_sz, red_sz);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ constexpr uint32_t shfl_mask(int threads)
{
return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, // The type of the inputs. Supported types: float and half.
int Dh, // The hidden dimension per head.
int Dh_MAX,
int THREADS_PER_KEY, // The number of threads per key.
int THREADS_PER_VALUE, // The number of threads per value.
int THREADS_PER_BLOCK, // The number of threads in a threadblock.
bool HAS_BEAMS>
__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params)
{
using Tk = typename kernel_type_t<T>::Type;
#ifdef ENABLE_FP8
// FP8 MHA Scales
constexpr bool FP8_MHA_KERNEL = std::is_same<T, __nv_fp8_e4m3>::value;
#else
constexpr bool FP8_MHA_KERNEL = false;
#endif
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
// The size of a warp.
constexpr int WARP_SIZE = 32;
// The number of warps in a threadblock.
constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern __shared__ char smem_[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float* qk_smem = reinterpret_cast<float*>(smem_);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char* logits_smem_ = smem_;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if (sizeof(Tk) != 4) {
// TODO - change to tlength
const int max_timesteps = min(params.timestep, params.memory_max_len);
logits_smem_ += div_up(max_timesteps + 1, 4) * 16;
}
Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
#else
float* logits_smem = reinterpret_cast<float*>(logits_smem_);
#endif
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
Tk* out_smem = reinterpret_cast<Tk*>(smem_);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
// A vector of Q or K elements for the current timestep.
using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type; // with kernel-used precision
using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type; // with memory-used precision
// Use alignment for safely casting the shared buffers as Qk_vec_k.
// Shared memory to store Q inputs.
__shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX];
// The number of elements per vector.
constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
// We will use block wide reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
// The number of vectors per warp.
constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
// The number of K vectors in 16B.
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m);
// The batch/beam idx
const int bi = blockIdx.y;
if (params.finished != nullptr && params.finished[bi] == true) {
return;
}
// The beam idx
const int beami = bi % params.beam_width;
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;
// The head.
const int hi = blockIdx.x;
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi;
// The thread in the block.
const int tidx = threadIdx.x;
constexpr bool handle_kv = true;
// While doing the product Q*K^T for the different keys we track the max.
float qk_max = -FLT_MAX;
float qk = 0.0F;
int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const size_t bi_seq_len_offset = bi * params.memory_max_len;
const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length;
const int first_step = max(0, tlength + 1 - params.memory_max_len);
const int tlength_circ = tlength % params.memory_max_len;
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;
// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
// The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
// Trigger the loads from the Q and K buffers.
Qk_vec_k q;
zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
const auto q_scaling = params.qkv_scale_out[0];
const auto q_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
}
else {
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset]));
}
}
Qk_vec_k k;
zero(k);
{
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
const auto k_scaling = params.qkv_scale_out[1];
const auto k_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
}
else {
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :
k;
}
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias;
zero(q_bias);
q_bias =
(!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[qk_bias_offset])) :
q_bias;
Qk_vec_k k_bias;
zero(k_bias);
if (handle_kv) {
k_bias =
!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[qk_bias_offset])) :
k_bias;
}
// Computes the Q/K values with bias.
q = add(q, q_bias);
if (handle_kv) {
k = add(k, k_bias);
}
if (do_ia3 && !is_masked) {
k = mul<Qk_vec_k, Qk_vec_k, Qk_vec_k>(
k,
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
&params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])));
}
// Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0) {
if (handle_kv) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
}
else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
}
}
if (!is_masked) {
// Store the Q values to shared memory.
*reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
// Write the K values to the global memory cache.
//
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
// system. We designed it this way as it allows much better memory loads (and there are many
// more loads) + the stores are really "write and forget" since we won't need the ack before
// the end of the kernel. There's plenty of time for the transactions to complete.
// The 16B chunk written by the thread.
int co = tidx / QK_VECS_IN_16B;
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
if (handle_kv) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
if (!params.k_cache_per_sample) {
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
+ tlength_circ * QK_ELTS_IN_16B + ci;
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
else {
int offset;
if (params.k_cache_interleaved) {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh
+ co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci;
}
else {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
}
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k);
}
}
}
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec_k>::Type;
#else
using Qk_vec_acum = Qk_vec_k;
#endif
qk = dot<Qk_vec_acum, Qk_vec_k>(q, k);
if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
}
}
}
if (QK_VECS_PER_WARP > WARP_SIZE) {
constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
}
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
if (tidx == 0) {
// Normalize qk.
qk *= params.inv_sqrt_dh;
if (params.relative_attention_bias != nullptr) {
qk = add(qk,
params.relative_attention_bias[hi * params.relative_attention_bias_stride
* params.relative_attention_bias_stride
+ (tlength - padd_len) * params.relative_attention_bias_stride
+ (tlength - padd_len)]);
}
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
qk_max = qk;
qk_smem[tlength - first_step] = qk;
// qk_smem[params.timestep] = qk;
}
// Make sure the data is in shared memory.
__syncthreads();
// The type of queries and keys for the math in the Q*K^T product.
using K_vec_k = typename K_vec_k_<T, THREADS_PER_KEY>::Type;
using K_vec_m = typename K_vec_m_<T, THREADS_PER_KEY>::Type;
// The number of elements per vector.
constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
// The number of elements per thread.
constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
// The number of vectors per thread.
constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
// The position the first key loaded by each thread from the cache buffer (for this B * H).
int ko = tidx / THREADS_PER_KEY;
// The position of the thread in the chunk of keys.
int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
// Load the Q values from shared memory. The values are reused during the loop on K.
K_vec_k q_vec[K_VECS_PER_THREAD];
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
q_vec[ii] = *reinterpret_cast<const K_vec_k*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
}
// The number of timesteps loaded per iteration.
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
// The number of keys per warp.
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
// The base pointer for the key in the cache buffer.
T* k_cache =
params.k_cache_per_sample ?
(params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
T* k_cache_batch = k_cache;
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
// prefix prompt length if has
const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
const int* beam_indices = HAS_BEAMS ? &params.cache_indir[bi_seq_len_offset] : nullptr;
for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
const int ti_circ = ti % params.memory_max_len;
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
// The keys loaded from the key cache.
K_vec_k k[K_VECS_PER_THREAD];
K_vec_k k_vec_zero;
zero(k_vec_zero);
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj =
params.k_cache_interleaved ? ii * params.memory_max_len + ti_circ : ti_circ * Dh / QK_ELTS_IN_16B + ii;
// if( ti < params.timestep ) {
const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
if (ti < tlength) {
if (!within_bounds) {
k[ii] = k_vec_zero;
}
else {
if (HAS_BEAMS) {
const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
}
else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
}
}
}
}
// Perform the dot product and normalize qk.
//
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
// Store the product to shared memory. There's one qk value per timestep. Update the max.
// if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
if (params.relative_attention_bias != nullptr) {
qk = add(qk,
params.relative_attention_bias[hi * params.relative_attention_bias_stride
* params.relative_attention_bias_stride
+ tlength * params.relative_attention_bias_stride + ti]);
}
if (params.linear_bias_slopes != nullptr) {
// Apply the linear position bias: (ki - qi) * slope[hi].
// The padding token locates between the input context and the generated tokens.
// We need to remove the number of padding tokens in the distance computation.
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
// token: i i i i p p p o o o where i=input, p=pad, o=output.
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
}
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = qk;
}
}
// Perform the final reduction to compute the max inside each warp.
//
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
// group so it's not needed to run the reduction inside the group (again).
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Decompose the thread index into warp and lane.
const int warp = tidx / WARP_SIZE;
const int lane = tidx % WARP_SIZE;
// The warp leader writes the max to shared memory.
if (lane == 0) {
red_smem[warp] = qk_max;
}
// Make sure the products are in shared memory.
__syncthreads();
// The warps finalize the reduction.
qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast to all the threads in the warp.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Compute the logits and start the sum.
float sum = 0.f;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
#ifdef FP8_MHA
float logit = 0.f;
if (FP8_MHA_KERNEL) {
logit = is_mask ? 0.f :
__expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0]
* params.query_weight_output_scale[0]);
}
else {
logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
}
#else
float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
#endif
sum += logit;
qk_smem[ti - first_step] = logit;
}
// Compute the sum.
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
// Normalize the logits.
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
float logit = qk_smem[ti - first_step] * inv_sum;
convert_from_float(logits_smem[ti - first_step], logit);
}
// Put Values part below so we leverage __syncthreads
// from the previous step
// The number of elements per vector.
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
// A vector of V elements for the current timestep.
using V_vec_k = typename V_vec_k_<T, V_VEC_SIZE>::Type;
using V_vec_m = typename V_vec_m_<T, V_VEC_SIZE>::Type;
// The value computed by this thread.
int vo = tidx / THREADS_PER_VALUE;
// The hidden dimensions computed by this particular thread.
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer.
T* v_cache =
params.v_cache_per_sample ?
(params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
T* v_cache_batch = v_cache;
// The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
zero(v_bias);
// if( vo == params.timestep % V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
if (handle_kv) {
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&params.v_bias[hi * Dh + vi]));
}
}
}
}
// From previous, before values, step
// Also make sure the logits are in shared memory.
__syncthreads();
// Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
using V_vec_acum = typename V_vec_acum_fp32_<V_vec_k>::Type;
#else
using V_vec_acum = V_vec_k;
#endif
// The partial outputs computed by each thread.
V_vec_acum out;
zero(out);
// Loop over the timesteps to compute the partial outputs.
// for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
// Separate the ti < memory_max_len and ti > memory_max_len
// to prevent ti % memory_len when ti < memory_len, and
// the compiler cannot optimize the codes automatically.
const int min_length = min(tlength, params.memory_max_len);
for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) {
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
// Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti - first_step];
out = fma(logit, cast_to_float(v), out);
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
#ifdef FP8_MHA
Tk logit;
if (FP8_MHA_KERNEL) {
// NOTE: fake quantization
// logit = vec_conversion<Tk, Tquant>(vec_conversion<Tquant, Tk>(mul<Tk, float, Tk>(1.0f /
// params.attention_qk_scale[0], logits_smem[ti])));
logit = logits_smem[ti - first_step];
}
else {
logit = logits_smem[ti - first_step];
}
out = fma(logit, v, out);
#else // FP8_MHA
Tk logit = logits_smem[ti - first_step];
out = fma(logit, v, out);
#endif // FP8_MHA
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
}
for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
if (ti < params.memory_max_len) {
// handled by previous loop
continue;
}
const int ti_circ = ti % params.memory_max_len;
// Fetch offset based on cache_indir when beam sampling
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache.
V_vec_k v = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
// Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti - first_step];
out = fma(logit, cast_to_float(v), out);
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
#ifdef FP8_MHA
Tk logit;
if (FP8_MHA_KERNEL) {
// NOTE: fake quantization
// logit = vec_conversion<Tk, Tquant>(vec_conversion<Tquant, Tk>(mul<Tk, float, Tk>(1.0f /
// params.attention_qk_scale[0], logits_smem[ti])));
logit = logits_smem[ti - first_step];
}
else {
logit = logits_smem[ti - first_step];
}
out = fma(logit, v, out);
#else // FP8_MHA
Tk logit = logits_smem[ti - first_step];
out = fma(logit, v, out);
#endif // FP8_MHA
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
}
}
// One group of threads computes the product(s) for the current timestep.
// if( vo == params.timestep % V_PER_ITER ) {
if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
V_vec_k v;
// Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi;
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<V_vec_k>::value>::type;
const auto v_scaling = params.qkv_scale_out[2];
const auto v_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
}
else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
}
// Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
// Compute the V values with bias.
if (handle_kv) {
v = add(v, v_bias);
if (do_ia3) {
v = mul<V_vec_k, V_vec_k, V_vec_k>(
v,
*reinterpret_cast<const V_vec_k*>(
&params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
}
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
}
// Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
// out = fma(logits_smem[params.timestep], cast_to_float(v), out);
out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
// out = fma(logits_smem[params.timestep], v, out);
#ifdef FP8_MHA
Tk logit;
if (FP8_MHA_KERNEL) {
// NOTE: fake quantization
// logit = mul<Tk, float, Tk>(1.0f / params.attention_qk_scale[0], logits_smem[tlength]);
logit = logits_smem[tlength - first_step];
}
else {
logit = logits_smem[tlength - first_step];
}
out = fma(logit, v, out);
#else // FP8_MHA
out = fma(logits_smem[tlength - first_step], v, out);
#endif // FP8_MHA
#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
}
// Make sure we can start writing to shared memory.
__syncthreads();
// Run the final reduction amongst the different groups computing different partial outputs.
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
// The midpoint in the number of active groups.
int midpoint = active_groups / 2;
// The upper part of active threads store to shared memory.
if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
convert_from_float(*reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
*reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
}
__syncthreads();
// The bottom warps update their values.
if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
out = add(*reinterpret_cast<const V_vec_k*>(&out_smem[vo * Dh + vi]), out);
}
__syncthreads();
}
}
// Output the final values.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
if (FP8_MHA_KERNEL) {
#ifdef FP8_MHA
// float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] *
// params.attention_output_weight_input_scale_inv[0];
float result_scale =
params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0];
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]),
mul<V_vec_acum, float, V_vec_acum>(result_scale, out));
#endif // FP8_MHA
}
else if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
cast_to_int8(out);
}
else {
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out);
}
#else // MMHA_USE_FP32_ACUM_FOR_OUT
// TODO: support int8_mode?
*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]) = vec_conversion<V_vec_m, V_vec_acum>(out);
#endif // MMHA_USE_FP32_ACUM_FOR_OUT
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace mmha
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Dh_MAX>
struct threads_per_value_t {
static const int value = Dh_MAX * sizeof(T) / 16;
};
#ifdef ENABLE_FP8
template<int Dh_MAX>
struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> {
static const int value = Dh_MAX * 4 / 16; // DEBUG: float v
};
#endif
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include <stdint.h>
using namespace fastertransformer;
namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Float4_ {
float2 x;
float2 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
#endif
#ifdef ENABLE_FP8
using fp8_2_t = __nv_fp8x2_e4m3;
using fp8_4_t = __nv_fp8x4_e4m3;
struct fp8_8_t {
__nv_fp8_e4m3 x;
__nv_fp8_e4m3 y;
__nv_fp8_e4m3 z;
__nv_fp8_e4m3 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<>
struct num_elems<float> {
static constexpr int value = 1;
};
template<>
struct num_elems<float2> {
static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
static constexpr int value = 4;
};
template<>
struct num_elems<Float4_> {
static constexpr int value = 4;
};
template<>
struct num_elems<Float8_> {
static constexpr int value = 8;
};
template<>
struct num_elems<uint32_t> {
static constexpr int value = 2;
};
template<>
struct num_elems<uint2> {
static constexpr int value = 4;
};
template<>
struct num_elems<uint4> {
static constexpr int value = 8;
};
#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
static constexpr int value = 8;
};
#endif
#ifdef ENABLE_FP8
template<>
struct num_elems<__nv_fp8_e4m3> {
static constexpr int value = 1;
};
template<>
struct num_elems<fp8_2_t> {
static constexpr int value = 2;
};
template<>
struct num_elems<fp8_4_t> {
static constexpr int value = 4;
};
template<>
struct num_elems<fp8_8_t> {
static constexpr int value = 8;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int N>
struct packed_type;
template<typename T>
struct packed_type<T, 1> {
using type = T;
};
template<>
struct packed_type<int8_t, 2> {
using type = int16_t;
};
template<>
struct packed_type<int8_t, 4> {
using type = int32_t;
};
template<>
struct packed_type<int8_t, 8> {
using type = int64_t;
};
template<>
struct packed_type<float, 2> {
using type = float2;
};
template<>
struct packed_type<float, 4> {
using type = float4;
};
template<>
struct packed_type<float, 8> {
using type = Float8_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, float b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(float2 a, float2 b)
{
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 add(float4 a, float4 b)
{
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
{
return a + b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t add(uint16_t a, uint16_t b)
{
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t add(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 add(uint2 a, uint2 b)
{
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 add(uint4 a, uint4 b)
{
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint16_t float_to_half(float f)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
float zero = 0.f;
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#endif
return tmp.u16[0];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t float2_to_half2(float2 f)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float half_to_float(uint16_t h)
{
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 half2_to_float2(uint32_t v)
{
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, uint16_t b)
{
return a + half_to_float(b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float add(float a, __nv_bfloat16 b)
{
return a + __bfloat162float(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_FP8
inline __device__ float add(float a, __nv_fp8_e4m3 b)
{
return a + (float)(b);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 add(uint32_t a, float2 fb)
{
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(uint2 a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(uint4 a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t h0_h0(uint16_t a)
{
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(float a, float b, float c)
{
return a * b + c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(float2 a, float2 b, float2 c)
{
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(float a, float2 b, float2 c)
{
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 fma(float4 a, float4 b, float4 c)
{
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 fma(float a, float4 b, float4 c)
{
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float4 fma(float a, float4 b, Float4_ c)
{
float4 d;
d.x = fma(a, b.x, c.x.x);
d.y = fma(a, b.y, c.x.y);
d.z = fma(a, b.z, c.y.x);
d.w = fma(a, b.w, c.y.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
{
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
{
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
{
float2 fa = bf1622float2(a);
return add(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
{
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
{
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
{
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
{
return fma(h0_h0(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
{
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
{
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
{
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
{
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(uint16_t a, uint16_t b, float fc)
{
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
{
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
{
return fma(h0_h0(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
{
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
{
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(a, b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
return bf16hfma2(bf162bf162(a), b, c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
{
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
{
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
{
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
{
return fma(bf162bf162(a), b, fc);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
{
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
{
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b)
{
return Acc{}; // for compile
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul<float, float>(float a, float b)
{
return a * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float2 a, float2 b)
{
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(float a, float2 b)
{
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float4 a, float4 b)
{
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float4 mul(float a, float4 b)
{
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(float a, Float8_ b)
{
Float8_ c;
c.x = mul<float2, float, float2>(a, b.x);
c.y = mul<float2, float, float2>(a, b.y);
c.z = mul<float2, float, float2>(a, b.z);
c.w = mul<float2, float, float2>(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b)
{
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b)
{
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b)
{
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint2 a, uint2 b)
{
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b)
{
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint4 a, uint4 b)
{
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b)
{
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, uint16_t b)
{
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(uint16_t a, float b)
{
return half_to_float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b)
{
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b)
{
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b)
{
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b)
{
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b)
{
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b)
{
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __hmul(a, b);
#else
return bf16hmul(a, b);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
return bf16hmul2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
{
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
{
float fa = (float)a;
float fb = (float)b;
return fa * fb;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float mul(__nv_bfloat16 a, float b)
{
return __bfloat162float(a) * b;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
{
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
{
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
{
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
{
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v)
{
return v;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float2 v)
{
return v.x + v.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float4 v)
{
return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v)
{
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_4_t v)
{
return sum(v.x) + sum(v.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(bf16_8_t v)
{
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint16_t v)
{
return half_to_float(v);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint32_t v)
{
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint2 v)
{
uint32_t c = add(v.x, v.y);
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(uint4 v)
{
#if 1
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
#else
uint32_t c = add(v.x, v.y);
uint32_t d = add(v.z, v.w);
c = add(c, d);
#endif
return sum(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float4_ v)
{
return v.x.x + v.x.y + v.y.x + v.y.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(Float8_ v)
{
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float dot(T a, T b)
{
return sum(mul<T, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
return sum(mul<A, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void zero(uint16_t& dst)
{
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void zero(T& dst)
{
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
{
const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
return {cos(inv_freq), sin(inv_freq)};
}
inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
float2 rot_v;
rot_v.x = coef.x * v.x - coef.y * v.y;
rot_v.y = coef.x * v.y + coef.y * v.x;
return rot_v;
}
inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
float2 fv = half2_to_float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return float2_to_half2(rot_fv);
}
#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
float2 fv = bf1622float2(v);
float2 rot_fv = rotary_embedding_transform(fv, coef);
return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
{
return;
}
inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
{
return;
}
inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
}
inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q_.x = rotary_embedding_transform(q_.x, coef0);
k_.x = rotary_embedding_transform(k_.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q_.y = rotary_embedding_transform(q_.y, coef1);
k_.y = rotary_embedding_transform(k_.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
}
inline __device__ void
apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
{
if (2 * tid >= rot_embed_dim) {
return;
}
const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
q = rotary_embedding_transform(q, coef);
k = rotary_embedding_transform(k, coef);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
{
if (4 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
}
inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
{
if (8 * tid >= rot_embed_dim) {
return;
}
const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
q.x = rotary_embedding_transform(q.x, coef0);
k.x = rotary_embedding_transform(k.x, coef0);
const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
q.y = rotary_embedding_transform(q.y, coef1);
k.y = rotary_embedding_transform(k.y, coef1);
const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
q.z = rotary_embedding_transform(q.z, coef2);
k.z = rotary_embedding_transform(k.z, coef2);
const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
q.w = rotary_embedding_transform(q.w, coef3);
k.w = rotary_embedding_transform(k.w, coef3);
}
#endif // ENABLE_BF16
template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
vec = tmp_3.u32x2;
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u16[0] = tmp_1.u16[0];
tmp_3.u16[1] = tmp_2.u16[0];
tmp_3.u16[2] = tmp_1.u16[1];
tmp_3.u16[3] = tmp_2.u16[1];
tmp_3.u16[4] = tmp_1.u16[2];
tmp_3.u16[5] = tmp_2.u16[2];
tmp_3.u16[6] = tmp_1.u16[3];
tmp_3.u16[7] = tmp_2.u16[3];
vec = tmp_3.u32x4;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
__nv_bfloat16 bf16[2];
} tmp_1, tmp_2;
tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}
template<>
__device__ __inline__ void
vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
__nv_bfloat16 bf16[4];
} tmp_1, tmp_2;
tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif // ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.z = smem[transpose_idx + 1];
vec.y = smem[smem_pitch + transpose_idx];
vec.w = smem[smem_pitch + transpose_idx + 1];
}
template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u16[0] = smem[transpose_idx];
tmp.u16[1] = smem[smem_pitch + transpose_idx];
vec = tmp.u32;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
#endif
template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
vec.x = smem[transpose_idx];
vec.y = smem[smem_pitch + transpose_idx];
}
template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
return;
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
template<>
__device__ __inline__ void
write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
return;
}
#endif
#ifdef ENABLE_FP8
template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch)
{
// TODO
printf("[ERROR] still no have implementation for vec_from_smem_transpose under __nv_fp8_e4m3 \n");
}
#endif // ENABLE_FP8
template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint64_t u64;
uint16_t u16[4];
} tmp_1, tmp_2;
union {
uint4 u32x4;
uint16_t u16[8];
} tmp_3;
tmp_3.u32x4 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
tmp_1.u16[2] = tmp_3.u16[4];
tmp_2.u16[2] = tmp_3.u16[5];
tmp_1.u16[3] = tmp_3.u16[6];
tmp_2.u16[3] = tmp_3.u16[7];
*reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
*reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp_1, tmp_2;
union {
uint2 u32x2;
uint16_t u16[4];
} tmp_3;
tmp_3.u32x2 = vec;
tmp_1.u16[0] = tmp_3.u16[0];
tmp_2.u16[0] = tmp_3.u16[1];
tmp_1.u16[1] = tmp_3.u16[2];
tmp_2.u16[1] = tmp_3.u16[3];
*reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
*reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[transpose_idx + 1] = vec.z;
smem[smem_pitch + transpose_idx] = vec.y;
smem[smem_pitch + transpose_idx + 1] = vec.w;
}
template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32 = vec;
smem[transpose_idx] = tmp.u16[0];
smem[smem_pitch + transpose_idx] = tmp.u16[1];
}
#ifdef ENABLE_BF16
template<>
__device__ __inline__ void
write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
#endif
template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}
#ifdef ENABLE_FP8
template<>
__device__ __inline__ void
write_smem_transpose(const float4& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch)
{
printf("[ERROR] still no have implementation for vec_from_smem_transpose under __nv_fp8_e4m3 \n");
}
#endif // ENABLE_FP8
} // namespace mmha
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/decoding_kernels.h"
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/utils/cuda_utils.h"
namespace fastertransformer {
// static const float HALF_FLT_MAX = 65504.F;
template<typename T>
__global__ void decodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
T* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length)
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? (T)HALF_FLT_MAX : (T)1e20f; // BF16 and FP32 have the same dynamic range
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
index += blockDim.x * gridDim.x) {
finished[index] = false;
sequence_length[index] = max_input_length;
if (word_ids != nullptr) {
word_ids[index] = sentence_ids[index / beam_width];
}
cum_log_probs[index] = (index % beam_width == 0) ? (T)0.0f : (T)-MAX_T_VAL;
}
}
template<typename T>
void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
T* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream)
{
dim3 grid((int)ceil(batch_size * beam_width * 1.0 / 256));
dim3 block(256);
decodingInitialize<T><<<grid, block, 0, stream>>>(
finished, sequence_length, word_ids, cum_log_probs, sentence_ids, batch_size, beam_width, max_input_length);
}
template void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
float* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
template void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
half* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
__nv_bfloat16* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
#endif
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T>
__global__ void embeddingLookupPosEncoding(T* from_tensor,
const T* embedding_table,
const T* position_encoding,
const int* all_ids,
const int* padding_count,
const int* input_lengths,
const int local_token_num,
const int64_t hidden_units,
const int step,
const int max_input_length,
const int token_num,
const int ite,
const T scale)
{
// 1. lookup from embedding table
// 2. multiply scale
// 3. add the position encoding
const int id_offset = step * token_num + ite * local_token_num;
const bool use_padding_count = padding_count != nullptr;
const bool use_input_len = input_lengths != nullptr;
for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
index += blockDim.x * gridDim.x) {
const int row_index = index / hidden_units;
const int col_index = index % hidden_units;
int step_offset = step;
if (use_padding_count) {
step_offset -= padding_count[row_index];
}
else if (use_input_len) {
step_offset -= max_input_length - input_lengths[row_index];
}
step_offset *= hidden_units;
T val = embedding_table[all_ids[id_offset + row_index] * hidden_units + col_index] * scale;
val = val + position_encoding[step_offset + col_index];
from_tensor[index] = val;
}
}
// No absolute position embedding
// PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts
template<typename T, int PROMPT_SRC>
__global__ void embeddingLookup(T* from_tensor,
const T* embedding_table,
const int* all_ids,
pPromptTuningParam<T> prompt_param,
const int local_token_num,
const int64_t hidden_units,
const int step,
const int token_num,
const int ite,
const int seq_len,
const T scale)
{
// 1. lookup from embedding table
// 2. multiply scale
const int id_offset = step * token_num + ite * local_token_num;
for (int64_t index = blockIdx.x * blockDim.x + threadIdx.x; index < local_token_num * hidden_units;
index += blockDim.x * gridDim.x) {
const int word_index = index / hidden_units;
const int word_index_row = word_index / seq_len; // batch_id
const int col_index = index % hidden_units;
const int input_id = all_ids == nullptr ? word_index : all_ids[id_offset + word_index];
const int prompt_id = input_id - prompt_param.p_prompt_tuning_id_start;
T embedding = (T)0.0f;
if (PROMPT_SRC > 0 && prompt_id >= 0) {
if (PROMPT_SRC == 1) {
// from loaded prompt embedding tables
embedding =
prompt_param.p_prompt_tuning_batch_weights[word_index_row][prompt_id * hidden_units + col_index];
}
else {
// from request prompt embedding
embedding =
prompt_param
.request_prompt_embedding[word_index_row * prompt_param.request_prompt_max_length * hidden_units
+ prompt_id * hidden_units + col_index];
}
}
else {
embedding = embedding_table[input_id * hidden_units + col_index];
}
from_tensor[index] = embedding * scale;
}
}
#define EMBEDDING_LOOKUP(PROMPT_SRC) \
embeddingLookup<T, PROMPT_SRC><<<grid, block, 0, stream>>>(from_tensor, \
embedding_table, \
all_ids, \
prompt_param, \
local_token_num, \
hidden_units, \
step, \
token_num, \
ite, \
seq_len, \
scale);
/* Adapter function for invokeEmbeddingLookupPosEncoding{PadCount,InputLen} */
template<typename T>
void invokeEmbeddingLookupPosEncoding(T* from_tensor,
const T* embedding_table,
const T* position_encoding,
const int* all_ids,
const int* padding_count,
const int* input_lengths,
pPromptTuningParam<T> prompt_param,
const int local_token_num,
const int hidden_units,
const T scale,
const int step,
const int max_input_length,
const int token_num,
const int ite,
const int seq_len,
cudaStream_t stream)
{
dim3 grid(min(local_token_num, 65536));
dim3 block(min(hidden_units, 1024));
if (position_encoding != nullptr) {
FT_CHECK_WITH_INFO(prompt_param.use_request_p_prompt_embedding == false
&& prompt_param.p_prompt_tuning_batch_weights == nullptr,
fmtstr("embeddingLookupPosEncoding still not support prompt tuning"));
embeddingLookupPosEncoding<T><<<grid, block, 0, stream>>>(from_tensor,
embedding_table,
position_encoding,
all_ids,
padding_count,
input_lengths,
local_token_num,
hidden_units,
step,
max_input_length,
token_num,
ite,
scale);
}
else {
if (prompt_param.use_request_p_prompt_embedding) {
EMBEDDING_LOOKUP(2);
}
else if (prompt_param.p_prompt_tuning_batch_weights != nullptr) {
EMBEDDING_LOOKUP(1);
}
else {
EMBEDDING_LOOKUP(0);
}
}
}
#undef EMBEDDING_LOOKUP
template<typename T>
void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor,
const T* embedding_table,
const T* position_encoding,
const int* all_ids,
const int* pad_count,
pPromptTuningParam<T> prompt_param,
const int local_token_num,
const int hidden_units,
const T scale,
const int step,
const int token_num,
const int ite,
const int seq_len,
cudaStream_t stream)
{
invokeEmbeddingLookupPosEncoding<T>(from_tensor,
embedding_table,
position_encoding,
all_ids,
pad_count,
nullptr,
prompt_param,
local_token_num,
hidden_units,
scale,
step,
0,
token_num,
ite,
seq_len,
stream);
}
#define INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(T) \
template void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor, \
const T* embedding_table, \
const T* position_encoding, \
const int* all_ids, \
const int* pad_count, \
pPromptTuningParam<T> prompt_param, \
const int local_token_num, \
const int hidden_units, \
const T scale, \
const int step, \
const int token_num, \
const int ite, \
const int seq_len, \
cudaStream_t stream)
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(float);
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(half);
#ifdef ENABLE_BF16
INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT(__nv_bfloat16);
#endif
#undef INSTANTIATE_LOOKUP_POS_ENCODING_PAD_COUNT
template<typename T>
__global__ void paddingEmbedding(T* padded_embedding_kernel,
T* padded_embedding_bias,
const T* embedding_kernel,
const T* embedding_bias,
const int64_t hidden_unit,
const int64_t vocab_size,
const int64_t vocab_size_padded)
{
for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
id += blockDim.x * gridDim.x) {
int row_id = id / vocab_size_padded;
int col_id = id % vocab_size_padded;
if (col_id < vocab_size) {
padded_embedding_kernel[id] = embedding_kernel[row_id * vocab_size + col_id];
}
else {
padded_embedding_kernel[id] = (T)(0.0f);
}
}
for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < vocab_size_padded; id += blockDim.x * gridDim.x) {
if (id < vocab_size) {
padded_embedding_bias[id] = embedding_bias[id];
}
else {
padded_embedding_bias[id] = (T)(0.0f);
}
}
}
template<typename T>
void invokePaddingEmbedding(T* padded_embedding_kernel,
T* padded_embedding_bias,
const T* embedding_kernel,
const T* embedding_bias,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(hidden_unit * vocab_size_padded / 512.)));
paddingEmbedding<<<grid, block, 0, stream>>>(padded_embedding_kernel,
padded_embedding_bias,
embedding_kernel,
embedding_bias,
hidden_unit,
vocab_size,
vocab_size_padded);
}
template void invokePaddingEmbedding(float* padded_embedding_kernel,
float* padded_embedding_bias,
const float* embedding_kernel,
const float* embedding_bias,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
template void invokePaddingEmbedding(half* padded_embedding_kernel,
half* padded_embedding_bias,
const half* embedding_kernel,
const half* embedding_bias,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokePaddingEmbedding(__nv_bfloat16* padded_embedding_kernel,
__nv_bfloat16* padded_embedding_bias,
const __nv_bfloat16* embedding_kernel,
const __nv_bfloat16* embedding_bias,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
#endif
template<typename T>
__global__ void paddingEmbeddingKernel(T* padded_embedding_kernel,
const T* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded)
{
for (int id = threadIdx.x + blockIdx.x * blockDim.x; id < hidden_unit * vocab_size_padded;
id += blockDim.x * gridDim.x) {
int row_id = id / hidden_unit;
int col_id = id % hidden_unit;
if (row_id < vocab_size) {
padded_embedding_kernel[id] = embedding_kernel[row_id * hidden_unit + col_id];
}
else {
padded_embedding_kernel[id] = (T)(0.0f);
}
}
}
template<typename T>
void invokePaddingEmbeddingKernel(T* padded_embedding_kernel,
const T* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream)
{
dim3 block(512);
dim3 grid((int)(ceil(hidden_unit * vocab_size_padded / 512.)));
paddingEmbeddingKernel<<<grid, block, 0, stream>>>(
padded_embedding_kernel, embedding_kernel, hidden_unit, vocab_size, vocab_size_padded);
}
template void invokePaddingEmbeddingKernel(float* padded_embedding_kernel,
const float* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
template void invokePaddingEmbeddingKernel(half* padded_embedding_kernel,
const half* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokePaddingEmbeddingKernel(__nv_bfloat16* padded_embedding_kernel,
const __nv_bfloat16* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
#endif
__global__ void gatherTree(gatherTreeParam param)
{
// PREFIX SOFT PROMPT
// beam: have six parts
// [prompt | input | input_padding | prompt_padding | generated output | padding (use end_token)]
// parents: have five parts
// [prompt | input | input_padding | prompt_padding | generated output | padding (use 0)]
// step_ids: need to remove prompt, input_padding and prompt_padding
// the shape is [input_length + requested_output_length, bs, beam_width]
// need to transpose to output_ids [bs, beam_width, input_length + requested_output_length]
// max_input_length: input + input_padding + prompt_padding
// P/PROMPT TUNING
// NOTE: input (real ids | prompt virtual ids) have already been preprocessed during embedding lookup, no prompt
// templates now beam: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use
// end_token)] parents: [input (real ids | prompt virtual ids) | input_padding | generated output | padding (use
// 0)] step_ids: need to remove virtual prompt ids in input ids
// the shape is [input_length (real input length, prompt length) + requested_output_length, bs, beam_width]
// need to transpose to output_ids [bs, beam_width, input_length + requested_output_length]
// max_input_length: input (real ids | prompt virtual ids) + input_padding
const int max_input_length = param.input_lengths == nullptr ? 0 : param.max_input_length;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < param.batch_size * param.beam_width;
i += gridDim.x * blockDim.x) {
const int batch = i / param.beam_width;
const int beam = i % param.beam_width;
const int prompt_len =
param.prefix_soft_prompt_lengths == nullptr ? 0 : param.prefix_soft_prompt_lengths[batch];
int input_len = param.input_lengths == nullptr ? 0 : param.input_lengths[i];
// virtual prompts mean the prompt embedded in input ids (with prompt templates) [p/prompt tuning]
const int virtual_prompt_length =
param.p_prompt_tuning_prompt_lengths == nullptr ? 0 : param.p_prompt_tuning_prompt_lengths[batch];
// real input length (without virtual prompts) [p/prompt tuning]
input_len -= virtual_prompt_length;
const int* parent_ids = param.parent_ids;
const int* step_ids = param.step_ids;
// TODO(bhsueh) optimize the reduce_max operation for large beam_width
int max_len = -1;
bool update_response_input_length = param.response_input_lengths != nullptr;
// int selected_beam_index = 0;
for (int j = 0; j < param.beam_width; j++) {
int tmp_len =
param.max_sequence_lengths[batch * param.beam_width + j] + param.max_sequence_length_final_step;
// also remove the length of the soft prompts, p_prompt_tuning
param.max_sequence_lengths[batch * param.beam_width + j] =
tmp_len - param.max_prefix_soft_prompt_length
- (param.max_input_length - param.max_input_without_prompt_length);
// update the response input length
if (update_response_input_length) {
param.response_input_lengths[batch * param.beam_width + j] = input_len - prompt_len;
}
if (tmp_len > max_len) {
max_len = tmp_len;
// selected_beam_index = j;
}
}
const int max_seq_len_b = min(param.max_time, max_len);
if (max_seq_len_b <= 0) {
continue;
}
#define GET_IX(time_ix, beam_ix) \
(param.batch_size * param.beam_width * (time_ix) + param.beam_width * batch + (beam_ix))
const int padding_offset_and_prompt_offset = max_input_length - input_len + prompt_len;
const int initial_tgt_ix = GET_IX(max_seq_len_b - 1 - padding_offset_and_prompt_offset, beam);
const int initial_parent_ix = GET_IX(max_seq_len_b - 1, beam);
param.beams[initial_tgt_ix] = __ldg(step_ids + initial_parent_ix);
int parent = parent_ids == nullptr ? 0 : __ldg(parent_ids + initial_parent_ix) % param.beam_width;
bool found_bad = false;
for (int level = max_seq_len_b - 2; level >= 0; --level) {
if (level < prompt_len || (level >= input_len && level < max_input_length)) {
continue;
}
int tgt_level = level >= max_input_length ? level - padding_offset_and_prompt_offset : level - prompt_len;
const int level_beam_ix = GET_IX(tgt_level, beam);
const int level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > param.beam_width) {
// param.beams[level_beam_ix] = -1;
param.beams[level_beam_ix] = param.end_tokens[batch];
parent = -1;
found_bad = true;
}
else {
param.beams[level_beam_ix] = __ldg(step_ids + level_parent_ix);
parent = parent_ids == nullptr ? 0 : __ldg(parent_ids + level_parent_ix) % param.beam_width;
}
}
// set the padded part as end_token
// input_len
for (int index = max_len - padding_offset_and_prompt_offset;
index < param.max_time - param.max_prefix_soft_prompt_length;
++index) {
param.beams[GET_IX(index, beam)] = param.end_tokens[batch];
}
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
if (!found_bad) {
bool finished = false;
// skip the step 0 because it is often the start token
int start_step = max_input_length == 0 ? 1 : max_input_length;
for (int time = start_step; time < max_seq_len_b; ++time) {
const int level_beam_ix = GET_IX(time, beam);
if (finished) {
param.beams[level_beam_ix] = param.end_tokens[batch];
}
else if (param.beams[level_beam_ix] == param.end_tokens[batch]) {
finished = true;
}
}
}
#undef GET_IX
// transpose on output_ids
// remove p_prompt tuning virtual tokens (end tokens)
int actual_output_length = param.max_time - param.max_prefix_soft_prompt_length
- (param.max_input_length - param.max_input_without_prompt_length);
if (param.output_ids != nullptr) {
for (int j = 0; j < actual_output_length; j++) {
param.output_ids[i * actual_output_length + j] =
param.beams[j * param.batch_size * param.beam_width + i];
}
}
}
}
void invokeGatherTree(int* beams,
int* max_sequence_lengths,
const int max_time,
const int batch_size,
const int beam_width,
const int* step_ids,
const int* parent_ids,
const int* end_tokens,
cudaStream_t stream)
{
gatherTreeParam param;
param.beams = beams;
param.max_sequence_lengths = max_sequence_lengths;
param.max_time = max_time;
param.batch_size = batch_size;
param.beam_width = beam_width;
param.step_ids = step_ids;
param.parent_ids = parent_ids;
param.end_tokens = end_tokens;
param.max_input_length = 1;
param.prefix_soft_prompt_lengths = nullptr;
param.stream = stream;
invokeGatherTree(param);
}
void invokeGatherTree(int* beams,
int* max_sequence_lengths,
const int max_time,
const int batch_size,
const int beam_width,
const int* step_ids,
const int* parent_ids,
const int* end_tokens,
const int max_input_length,
cudaStream_t stream)
{
gatherTreeParam param;
param.beams = beams;
param.max_sequence_lengths = max_sequence_lengths;
param.max_time = max_time;
param.batch_size = batch_size;
param.beam_width = beam_width;
param.step_ids = step_ids;
param.parent_ids = parent_ids;
param.end_tokens = end_tokens;
param.max_input_length = max_input_length;
param.prefix_soft_prompt_lengths = nullptr;
param.stream = stream;
invokeGatherTree(param);
}
void invokeGatherTree(gatherTreeParam param)
{
int batchbeam = param.batch_size * param.beam_width;
dim3 grid(1), block(batchbeam);
// though decoder do not support > 1024 for now
if (batchbeam > 1024) {
grid.x = ceil(param.batch_size * param.beam_width / 1024.);
block.x = 1024;
}
gatherTree<<<grid, block, 0, param.stream>>>(param);
}
__global__ void minusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num)
{
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < token_num; i += blockDim.x * gridDim.x) {
if (finished[i] == false) {
sequence_lengths[i] -= 1;
}
}
}
void invokeMinusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream)
{
dim3 block(min(256, token_num));
dim3 grid(ceil(token_num / 256.));
minusUnfinishedSeqlen<<<block, grid, 0, stream>>>(sequence_lengths, finished, token_num);
}
__global__ void plusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num)
{
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < token_num; i += blockDim.x * gridDim.x) {
if (finished[i] == false) {
sequence_lengths[i] += 1;
}
}
}
void invokePlusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream)
{
dim3 block(min(256, token_num));
dim3 grid(ceil(token_num / 256.));
plusUnfinishedSeqlen<<<block, grid, 0, stream>>>(sequence_lengths, finished, token_num);
}
template<typename T>
__global__ void plusScalar(T* buf, const T val, const int size)
{
for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += blockDim.x * gridDim.x) {
buf[i] += val;
}
}
template<typename T>
void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream)
{
dim3 block(min(256, size));
dim3 grid(ceil(size / 256.));
plusScalar<<<block, grid, 0, stream>>>(buf, val, size);
}
template void invokePlusScalar(int* buf, const int val, const int size, cudaStream_t stream);
__global__ void finalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len)
{
// output_ids: [bs, beam_width, max_seq_len]
// sequence_lengths: [bs, beam_width]
// cum_log_probs: [bs, beam_width]
// output_log_probs: [bs, beam_width, max_seq_len]
// topk_output_ids: [bs, 2 * beam_width, max_seq_len + 1]
// topk_sequence_lengths: [bs, 2 * beam_width]
// scores: [bs, 2 * beam_width]
// topk_cum_log_probs: [bs, 2 * beam_width]
// topk_log_probs: [bs, 2 * beam_width, max_seq_len + 1]
// num_beams: [bs]
// This kernel do a sorting for scores first, and then put the topk_output_ids
// into output_ids by the rank of scores.
// Note that we remove the start_token (the id at first position) from topk_output_ids
extern __shared__ char array[];
int* rank = (int*)(array);
float* s_scores = (float*)(rank + beam_width);
if (threadIdx.x < num_beams[blockIdx.x]) {
s_scores[threadIdx.x] = scores[blockIdx.x * beam_width * 2 + threadIdx.x];
}
__syncthreads();
for (int i = 0; i < beam_width; i++) {
float score = threadIdx.x < num_beams[blockIdx.x] ? s_scores[threadIdx.x] : -FLT_MAX;
float max_score = blockReduceMax<float>(score);
if (threadIdx.x == 0) {
for (int j = 0; j < beam_width * 2; j++) {
if (s_scores[j] == max_score) {
rank[i] = j;
s_scores[j] = -FLT_MAX;
break;
}
}
}
__syncthreads();
}
if (threadIdx.x < beam_width) {
sequence_lengths[blockIdx.x * beam_width + threadIdx.x] =
topk_sequence_lengths[blockIdx.x * beam_width * 2 + rank[threadIdx.x]];
if (cum_log_probs != nullptr) {
cum_log_probs[blockIdx.x * beam_width + threadIdx.x] =
topk_cum_log_probs[blockIdx.x * beam_width * 2 + rank[threadIdx.x]];
}
}
for (int beam_idx = 0; beam_idx < beam_width; beam_idx++) {
// start from step 1 to skip the start token
for (int i = threadIdx.x; i < sequence_lengths[blockIdx.x * beam_width + beam_idx]; i += blockDim.x) {
output_ids[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] =
topk_output_ids[blockIdx.x * (beam_width * 2) * (max_seq_len + 1) + rank[beam_idx] * (max_seq_len + 1)
+ (i + 1)];
if (output_log_probs != nullptr) {
output_log_probs[blockIdx.x * beam_width * max_seq_len + beam_idx * max_seq_len + i] =
topk_log_probs[blockIdx.x * (beam_width * 2) * (max_seq_len + 1)
+ rank[beam_idx] * (max_seq_len + 1) + (i + 1)];
}
}
}
}
void invokeFinalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len,
const int batch_size,
cudaStream_t stream)
{
dim3 block(beam_width * 2);
block.x = (block.x + 31) / 32 * 32;
FT_CHECK(block.x < 1024);
finalize<<<batch_size, block, beam_width * sizeof(int) + (beam_width * 2) * sizeof(float), stream>>>(
output_ids,
sequence_lengths,
cum_log_probs,
output_log_probs,
topk_output_ids,
topk_sequence_lengths,
scores,
topk_cum_log_probs,
topk_log_probs,
num_beams,
beam_width,
max_seq_len);
}
} // namespace fastertransformer
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "gpt_kernels.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
namespace fastertransformer {
template<typename T>
void invokeDecodingInitialize(bool* finished,
int* sequence_length,
int* word_ids,
T* cum_log_probs,
const int* sentence_ids,
const int batch_size,
const int beam_width,
const int max_input_length,
cudaStream_t stream);
// get token from all_ids at step, then lookup from the embedding table
// by the token
template<typename T>
void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor,
const T* embedding_table,
const T* position_encoding,
const int* all_ids,
const int* padding_count,
pPromptTuningParam<T> prompt_param,
const int local_token_num,
const int hidden_units,
const T scale,
const int step,
const int token_num,
const int ite,
const int seq_len,
cudaStream_t stream);
template<typename T>
void invokeEmbeddingLookupPosEncodingPadCount(T* from_tensor,
const T* embedding_table,
const T* position_encoding,
const int* all_ids,
const int* padding_count,
const int local_token_num,
const int hidden_units,
const T scale,
const int step,
const int token_num,
const int ite,
cudaStream_t stream)
{
invokeEmbeddingLookupPosEncodingPadCount(from_tensor,
embedding_table,
position_encoding,
all_ids,
padding_count,
{(const T**)nullptr, 0, 0, false, nullptr},
local_token_num,
hidden_units,
scale,
step,
token_num,
ite,
0,
stream);
}
template<typename T>
void invokePaddingEmbedding(T* padded_embedding_kernel,
T* padded_embedding_bias,
const T* embedding_kernel,
const T* embedding_bias,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
template<typename T>
void invokePaddingEmbeddingKernel(T* padded_embedding_kernel,
const T* embedding_kernel,
const int hidden_unit,
const int vocab_size,
const int vocab_size_padded,
cudaStream_t stream);
void invokeGatherTree(int* beams,
int* max_sequence_lengths,
const int max_time,
const int batch_size,
const int beam_width,
const int* step_ids,
const int* parent_ids,
const int* end_tokens,
cudaStream_t stream);
void invokeGatherTree(int* beams,
int* max_sequence_lengths,
const int max_time,
const int batch_size,
const int beam_width,
const int* step_ids,
const int* parent_ids,
const int* end_tokens,
const int max_input_length,
cudaStream_t stream);
struct gatherTreeParam {
int* beams = nullptr;
int* max_sequence_lengths = nullptr;
int max_sequence_length_final_step = 0;
const int* input_lengths = nullptr;
// response input lengths (used to slice the ids during postprocessing)
int* response_input_lengths = nullptr;
int max_time = 0;
int batch_size = 0;
int beam_width = 0;
const int* step_ids = nullptr;
const int* parent_ids = nullptr;
const int* end_tokens = nullptr;
int max_input_length = 0;
const int* prefix_soft_prompt_lengths = nullptr;
// p_prompt_tuning prompt leangths, used to remove prompts during post-processing
const int* p_prompt_tuning_prompt_lengths = nullptr;
int max_input_without_prompt_length = 0;
// prefix soft prompt
int max_prefix_soft_prompt_length = 0;
int* output_ids = nullptr;
cudaStream_t stream;
};
void invokeGatherTree(gatherTreeParam param);
void invokeMinusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream);
void invokePlusUnfinishedSeqlen(int* sequence_lengths, const bool* finished, const int token_num, cudaStream_t stream);
template<typename T>
void invokePlusScalar(T* buf, const T val, const int size, cudaStream_t stream);
void invokeFinalize(int* output_ids,
int* sequence_lengths,
float* cum_log_probs,
float* output_log_probs,
const int* topk_output_ids,
const int* topk_sequence_lengths,
const float* scores,
const float* topk_cum_log_probs,
const float* topk_log_probs,
const int* num_beams,
const int beam_width,
const int max_seq_len,
const int batch_size,
cudaStream_t stream);
} // namespace fastertransformer
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cublas_v2.h"
#include "gen_relative_pos_bias.h"
#include "reduce_kernel_utils.cuh"
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include <cstdio>
namespace fastertransformer {
/******************* invokeGenRelativePosBias ***********************/
// relative_position_bias_table is [(2*window_size-1)*(2*window_size-1), headNum]
// relative_position_bias is [head_num, window_size^2, window_size^2]
// grid(window_size*window_size, head_num)
// block(window_size*window_size)
template<typename T, typename Tindex>
__global__ void gen_relative_pos_bias(T* relative_position_bias,
const T* relative_position_bias_table,
const Tindex* relative_position_bias_index,
const int window_size,
const int head_num)
{
const int h_in_window = blockIdx.x / window_size;
const int w_in_window = blockIdx.x % window_size;
const int h_in_token = threadIdx.x / window_size;
const int w_in_token = threadIdx.x % window_size;
const int head_idx = blockIdx.y;
const int elements_per_window = window_size * window_size;
const size_t elements_per_window_2 = elements_per_window * elements_per_window;
const size_t output_idx = head_idx * elements_per_window_2 + blockIdx.x * elements_per_window + threadIdx.x;
if (output_idx < head_num * elements_per_window_2) {
const Tindex idx_in_table =
relative_position_bias_index[(h_in_window * window_size + w_in_window) * elements_per_window
+ h_in_token * window_size + w_in_token];
relative_position_bias[output_idx] = relative_position_bias_table[idx_in_table * head_num + head_idx];
}
}
template<typename T, typename Tindex>
void invokeGenRelativePosBias(T* relative_position_bias,
const T* relative_position_bias_table,
const Tindex* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream)
{
dim3 grid(window_size * window_size, head_num);
dim3 block(window_size * window_size);
if (block.x > 1024) {
printf("[ERROR][invokeGenRelativePosBias] window_size*window_size > 1024.\n");
exit(-1);
}
gen_relative_pos_bias<<<grid, block, 0, stream>>>(
relative_position_bias, relative_position_bias_table, relative_position_bias_index, window_size, head_num);
}
/******************* invokeGenRelativePosBiasV2 ***********************/
template<typename T, typename Tindex>
void invokeGenRelativePosBiasV2(T* relative_position_bias,
const T* relative_coords_table,
const Tindex* relative_position_bias_index,
const T* cpb_mlp_weight1,
const T* cpb_mlp_bias1,
const T* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream)
{
dim3 grid(window_size * window_size, head_num);
dim3 block(window_size * window_size);
if (block.x > 1024) {
printf("[ERROR][invokeGenRelativePosBias] window_size*window_size > 1024.\n");
exit(-1);
}
T* relative_position_bias_table;
check_cuda_error(cudaMalloc(&relative_position_bias_table,
((2 * window_size - 1) * (2 * window_size - 1) * head_num) * sizeof(T)));
T* cpb_mlp_1;
check_cuda_error(
cudaMalloc(&cpb_mlp_1, ((2 * window_size - 1) * (2 * window_size - 1) * cpb_mlp_out_dim) * sizeof(T)));
cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle));
int m = (2 * window_size - 1) * (2 * window_size - 1);
T alpha = (T)1.0f;
T beta = (T)0.0f;
cudaDataType_t type = std::is_same<float, T>::value ? CUDA_R_32F : CUDA_R_16F;
#if (CUDART_VERSION >= 11000)
cublasComputeType_t compute_type = std::is_same<float, T>::value ? CUBLAS_COMPUTE_32F : CUBLAS_COMPUTE_16F;
#else
cudaDataType_t compute_type = std::is_same<float, T>::value ? CUDA_R_32F : CUDA_R_16F;
#endif
cublasGemmAlgo_t algo = std::is_same<float, T>::value ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;
check_cuda_error(cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
cpb_mlp_out_dim,
m,
cpb_mlp_in_dim,
&alpha,
cpb_mlp_weight1,
type,
cpb_mlp_in_dim,
relative_coords_table,
type,
cpb_mlp_in_dim,
&beta,
cpb_mlp_1,
type,
cpb_mlp_out_dim,
compute_type,
algo));
invokeGenericActivation<ReluActivation, T, T>(
cpb_mlp_1, cpb_mlp_bias1, nullptr, nullptr, nullptr, nullptr, m, cpb_mlp_out_dim, 0, nullptr, nullptr, stream);
check_cuda_error(cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
head_num,
m,
cpb_mlp_out_dim,
&alpha,
cpb_mlp_weight2,
type,
cpb_mlp_out_dim,
cpb_mlp_1,
type,
cpb_mlp_out_dim,
&beta,
relative_position_bias_table,
type,
head_num,
compute_type,
algo));
gen_relative_pos_bias<<<grid, block, 0, stream>>>(
relative_position_bias, relative_position_bias_table, relative_position_bias_index, window_size, head_num);
invokeSigmoid(
relative_position_bias, window_size * window_size * window_size * window_size * head_num, 16.0f, stream);
check_cuda_error(cudaFree(relative_position_bias_table));
check_cuda_error(cudaFree(cpb_mlp_1));
check_cuda_error(cublasDestroy(cublas_handle));
}
/******************* instantiation ***********************/
template void invokeGenRelativePosBias(float* relative_position_bias,
const float* relative_position_bias_table,
const int* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBias(half* relative_position_bias,
const half* relative_position_bias_table,
const int* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBias(float* relative_position_bias,
const float* relative_position_bias_table,
const int64_t* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBias(half* relative_position_bias,
const half* relative_position_bias_table,
const int64_t* relative_position_bias_index,
const int window_size,
const int head_num,
cudaStream_t stream);
__host__ __device__ uint32_t pow2_rounddown(uint32_t x)
{
x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;
x >>= 1;
return x + 1;
}
template<typename T>
__global__ void generate_alibi_slopes(T* alibi_slopes, const size_t num_heads)
{
if (threadIdx.x < num_heads) {
// The nearest power of 2 greater than num_heads followed by HF's implementation.
int num_heads_pow2 = pow2_rounddown(num_heads);
// Loop over the attention head.
for (int h = threadIdx.x; h < num_heads; h += blockDim.x) {
if (h < num_heads_pow2) {
alibi_slopes[h] = static_cast<T>(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1));
}
else {
alibi_slopes[h] = static_cast<T>(
powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1));
}
}
}
}
template<typename T>
void invokeBuildAlibiSlopes(T* alibi_slopes, const size_t num_heads, cudaStream_t stream)
{
// Generate the slopes of a linear attention linear bias.
//
// Paper: https://arxiv.org/abs/2108.12409
// HF's implementation
// https://github.com/huggingface/transformers/blob/56ef0ba44765162f830873c140bd40bdc975cc34/src/transformers/models/bloom/modeling_bloom.py#L86
// Author's implementation
// https://github.com/ofirpress/attention_with_linear_biases/blob/02aa87e7a29e9340efd28d6d169018eafb3aa57a/fairseq/models/transformer.py#L760
//
// alibi_slopes: [num_heads],
// strictly follows how HF implements. which treats power-of-2 heads, and non-power-of-2 heads differently.
// what paper generates differs with HF's when number of heads is not a power of 2.
// num_heads: the number of attention heads.
// stream: a cuda stream.
dim3 block(min((int)num_heads, 512));
generate_alibi_slopes<<<1, block, 0, stream>>>(alibi_slopes, num_heads);
}
template void invokeBuildAlibiSlopes(float* alibi_slopes, const size_t num_heads, cudaStream_t stream);
template void invokeBuildAlibiSlopes(half* alibi_slopes, const size_t num_heads, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBuildAlibiSlopes(__nv_bfloat16* alibi_slopes, const size_t num_heads, cudaStream_t stream);
#endif
template void invokeGenRelativePosBiasV2(float* relative_position_bias,
const float* relative_coords_table,
const int* relative_position_bias_index,
const float* cpb_mlp_weight1,
const float* cpb_mlp_bias1,
const float* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBiasV2(half* relative_position_bias,
const half* relative_coords_table,
const int* relative_position_bias_index,
const half* cpb_mlp_weight1,
const half* cpb_mlp_bias1,
const half* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBiasV2(float* relative_position_bias,
const float* relative_coords_table,
const int64_t* relative_position_bias_index,
const float* cpb_mlp_weight1,
const float* cpb_mlp_bias1,
const float* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream);
template void invokeGenRelativePosBiasV2(half* relative_position_bias,
const half* relative_coords_table,
const int64_t* relative_position_bias_index,
const half* cpb_mlp_weight1,
const half* cpb_mlp_bias1,
const half* cpb_mlp_weight2,
const int window_size,
const int cpb_mlp_in_dim,
const int cpb_mlp_out_dim,
const int head_num,
cudaStream_t stream);
} // namespace fastertransformer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment