Unverified Commit ab1767cf authored by Li Zhang, committed by GitHub

TurboMind 2 (#590)

* refresh decoder attention kernel

* block-level kv cache

* `BlockManager` & `SequenceManager`

* update

* update

* update

* update

* rename

* GQA support

* fix context length

* GQA dispatch

* kv8

* tune

* async stream cb

* nvtx

* config parsing

* debug

* optimize output cost

* split-k decoding

* minor

* truncate `session_len` by available blocks

* minor

* license

* fix

* dispatch `cp.async`

* fix linking

* fix

* fix deadlock

* guard input length

* correct start offset

* fix prefill chunking

* fix `cache_block_seq_len` param passing

* fix `block_size` fmtstr

* fix output tokens

* fix batch resizing

* fix masking of finished sequences

* add debug util

* free unused block early

* add ntk scaling and logn scaling

* cmake flags

* fix typo

* w4a16 for sm75

* fix msvc build

* fix msvc build

* fix block verification

* fix msvc build

* use `std::shuffle`

* fix lint

* fix lint

* fix lint

* clear incoming buffer

* clear finished requests

* fix batch initialization

* fix typo

* fix typo

* fix comparison
parent 06125966
@@ -72,3 +72,5 @@ work_dir*/
*.out
*.csv
*.pkl
!CMakeLists.txt
@@ -61,6 +61,22 @@ option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF
option(BUILD_FAST_MATH "Build in fast math mode" ON)
# the environment variable
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
# must be set at runtime
# https://github.com/google/sanitizers/issues/1322
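# e.g. (hypothetical binary name, for illustration only):
#   ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0 ./bin/llama_triton_example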
if (LMDEPLOY_ASAN_ENABLE)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)
add_link_options(-fsanitize=address)
endif ()
# notice that ubsan has linker issues for ubuntu < 18.04, see
# https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed
if (LMDEPLOY_UBSAN_ENABLE)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)
add_link_options(-fsanitize=undefined)
endif ()
if(BUILD_MULTI_GPU)
message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL")
add_definitions("-DBUILD_MULTI_GPU")
@@ -181,11 +197,15 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
if(BUILD_FAST_MATH)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math")
message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}")
endif()
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -252,11 +272,15 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
OUTPUT_VARIABLE USE_CXX11_ABI)
message("-- USE_CXX11_ABI=${USE_CXX11_ABI}")
if (USE_CXX11_ABI)
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
else()
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
@@ -327,6 +351,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:cuda_utils>
$<TARGET_OBJECTS:custom_ar_comm>
$<TARGET_OBJECTS:custom_ar_kernels>
$<TARGET_OBJECTS:decoder_multihead_attention>
$<TARGET_OBJECTS:decoder_masked_multihead_attention>
$<TARGET_OBJECTS:decoding_kernels>
$<TARGET_OBJECTS:gpt_kernels>
@@ -80,8 +80,10 @@ broadCastRequest(const std::vector<int>& v_start_ids,
if (node_id == 0) {
memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int));
memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int));
if (!v_input_bad_words.empty()) {
memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int));
}
}
if (kUSE_MPI) {
ft::mpi::barrier();
}
@@ -431,6 +433,8 @@ int main(int argc, char* argv[])
const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1];
const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2];
ft::FT_CHECK(beam_width == 1);
std::vector<int> seq_lens(batch_size);
// step 6: check results
if (node_id == 0) {
@@ -440,32 +444,25 @@ int main(int argc, char* argv[])
printf("[WARNING] Cannot write results into output file %s \n", fName.c_str());
}
else {
size_t outCount = batch_size * beam_width * seq_len;
// int* hBuf = new int[outCount];
const size_t outCount = batch_size * beam_width * seq_len;
std::vector<int> hBuf(outCount);
ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
std::cout << "sequence length: ";
for (int i = 0; i < batch_size; ++i) {
std::cout << (i ? ", " : "") << seq_lens[i];
}
std::cout << "\n";
{
std::cout << "Writing " << outCount << " elements\n";
int zeroCount = 0;
for (size_t i = 0; i < outCount; i++) {
if (hBuf[i] == int(0))
zeroCount++;
outFile << hBuf[i] << " ";
if ((i + 1) % (seq_len) == 0)
outFile << std::endl;
if (i < 10)
printf("%5d ", hBuf[i]);
if ((i + 1) % (seq_len) == 0 && i < 10)
std::cout << std::endl;
for (int i = 0; i < batch_size; ++i) {
outFile << (i ? "\n" : "");
auto buf = hBuf.data() + seq_len * i;
for (int j = 0; j < seq_lens[i]; ++j) {
outFile << buf[j] << " ";
}
std::cout << std::endl << "zeroCount = " << zeroCount << std::endl;
}
}
}
@@ -475,7 +472,7 @@ int main(int argc, char* argv[])
}
cudaDeviceSynchronize();
if (1) {
if (0) {
// test time
auto start = std::chrono::high_resolution_clock::now();
@@ -71,3 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(gemm_s_f16)
add_subdirectory(decoder_multihead_attention)
@@ -43,6 +43,12 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
// cudaFuncAttributes attr{}; \
// cudaFuncGetAttributes(&attr, func); \
// std::cout << "static_smem_sz: " << attr.sharedSizeBytes << std::endl; \
// std::cout << "max_dynamic_smem: " << attr.maxDynamicSharedSizeBytes << std::endl; \
// std::cout << "dynamic_smem_sz: " << smem_sz << std::endl; \
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
@@ -1472,6 +1472,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
printf("QK_last[%d] = %f\n", hi, qk);
qk_max = qk;
qk_smem[tlength - first_step] = qk;
// qk_smem[params.timestep] = qk;
@@ -1596,6 +1598,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
}
// printf("QK_%d = %f\n", (int)ti, qk);
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = qk;
}
@@ -1632,6 +1635,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Broadcast to all the threads in the warp.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
if (threadIdx.x == 0) {
printf("QK_MAX[%d] = %f\n", hi, (float)qk_max);
}
// Compute the logits and start the sum.
float sum = 0.f;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
@@ -1657,6 +1664,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Compute the sum.
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
if (threadIdx.x == 0) {
printf("SUM[%d] = %f\n", hi, (float)sum);
}
// Normalize the logits.
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
# Copyright (c) OpenMMLab. All rights reserved.
add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
# target_compile_options(decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
# target_compile_options(test_decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
target_link_libraries(test_decoder_multihead_attention PRIVATE
decoder_multihead_attention
decoder_masked_multihead_attention
cublas)
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "src/turbomind/kernels/gemm_s_f16/common.h"
#include <cfloat>
#include <limits>
namespace turbomind {
namespace ops {
template<typename T>
struct plus {
__device__ T operator()(T a, T b)
{
return a + b;
}
};
template<typename T>
struct minus {
__device__ T operator()(T a, T b)
{
return a - b;
}
};
template<typename T>
struct multiplies {
__device__ T operator()(T a, T b)
{
return a * b;
}
};
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_vv(const Array<T, N>& a, const Array<T, N>& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a[i], b[i]);
}
return c;
}
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_sv(const T& a, const Array<T, N>& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a, b[i]);
}
return c;
}
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_vs(const Array<T, N>& a, const T& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a[i], b);
}
return c;
}
template<typename T, int N>
inline __device__ Array<T, N> operator+(const Array<T, N>& a, const Array<T, N>& b)
{
return binary_op_vv(a, b, plus<T>{});
}
template<typename T, int N>
inline __device__ Array<T, N> operator*(const Array<T, N>& a, const Array<T, N>& b)
{
return binary_op_vv(a, b, multiplies<T>{});
}
template<typename T, int N>
inline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)
{
return binary_op_vs(a, b, multiplies<T>{});
}
} // namespace ops
template<typename To, typename From, int N>
inline __device__ Array<To, N> cast(const Array<From, N>& src)
{
Array<To, N> dst;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = (To)src[i];
}
return dst;
}
template<int N>
struct RotaryEmbedding {
static_assert(N % 2 == 0);
Array<float, N> cs_;
__device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep);
cs_[i] = tmp.x;
cs_[i + 1] = tmp.y;
}
}
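    // rotation angle for the dimension pair starting at `idx`, at position `timestep`:
    // theta = timestep / base^(idx / dims); returns (cos(theta), sin(theta))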
static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep)
{
const float inv_freq = timestep / powf(base, idx / (float)dims);
float2 cs;
sincosf(inv_freq, &cs.y, &cs.x);
return cs;
}
template<typename T>
__device__ void apply(Array<T, N>& x)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1];
float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i];
x[i] = (T)tmp0;
x[i + 1] = (T)tmp1;
}
}
};
struct LogNScaling {
float scale_;
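    // log(n) attention scaling: for seq_len n beyond the trained context m,
    // scale by log_m(n), computed as log2f(n) / log2f(m); otherwise leave Q unchanged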
__device__ static float get_scale(int seq_len, int max_position_embeddings)
{
if (seq_len <= max_position_embeddings) {
return 1.f;
}
else {
return log2f(seq_len) / log2f(max_position_embeddings);
}
}
__device__ LogNScaling(int seq_len, int max_position_embeddings)
{
scale_ = get_scale(seq_len, max_position_embeddings);
}
template<typename T, int N>
__device__ void apply(Array<T, N>& x) const
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
x[i] = (T)((float)x[i] * scale_);
}
}
};
template<typename T, int N>
inline __device__ void Store(T* dst, const Array<T, N>& src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
*(uint4*)dst = (const uint4&)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
*(uint2*)dst = (const uint2&)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {
*(uint1*)dst = (const uint1&)src;
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename T, int N>
inline __device__ void Ldg(Array<T, N>& dst, const T* src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
(uint4&)dst = __ldg((const uint4*)src);
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
(uint2&)dst = __ldg((const uint2*)src);
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
(uint&)dst = __ldg((const uint*)src);
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename T, int N>
inline __device__ void Lds(Array<T, N>& dst, const T* src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
(uint4&)dst = *(const uint4*)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
(uint2&)dst = *(const uint2*)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
(uint1&)dst = *(const uint1*)src;
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N, int V>
inline __device__ Accum qk_dot(const Array<Tq, N> (&q)[V], const Array<Tk, N> (&k)[V])
{
Accum accum{};
PRAGMA_UNROLL
for (int vi = 0; vi < V; ++vi) {
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
accum += Accum(Compute(q[vi][i]) * Compute(k[vi][i]));
}
}
PRAGMA_UNROLL
for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
}
return accum;
}
template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N>
inline __device__ Accum qk_dot(const Array<Tq, N>& q, const Array<Tk, N>& k)
{
Accum accum{};
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
accum += Accum(Compute(q[i]) * Compute(k[i]));
}
PRAGMA_UNROLL
for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
}
return accum;
}
template<typename ComputeType, typename Tp, typename Tv, typename To, int N, int M>
inline __device__ void fma_pv(Tp pr, const Array<Tv, N> (&v)[M], Array<To, N> (&o)[M])
{
PRAGMA_UNROLL
for (int m = 0; m < M; ++m) {
PRAGMA_UNROLL
for (int n = 0; n < N; ++n) {
o[m][n] += To(ComputeType(v[m][n]) * ComputeType(pr));
}
}
}
template<typename ThreadMap, typename T, int N>
inline __device__ Array<T, N> qk_max(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
{
constexpr int kWarpCount = ThreadMap::kWarpCount;
// warp maximum
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
PRAGMA_UNROLL
for (int mask = WARP_SIZE / 2; mask >= ThreadMap::kWarpThreadC; mask /= 2) {
val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
}
if (lane_id == 0) {
smem_red[i * kWarpCount + warp_id] = val[i];
}
}
__syncthreads();
// block maximum
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : -FLT_MAX;
PRAGMA_UNROLL
for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
}
// broadcast to all threads
val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
}
return val;
}
template<int kWarpCount, typename T, int N>
inline __device__ Array<T, N> blockSum(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
PRAGMA_UNROLL
for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
}
if (lane_id == 0) {
smem_red[i * kWarpCount + warp_id] = val[i];
}
}
__syncthreads();
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};
PRAGMA_UNROLL
for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
}
val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
}
return val;
}
//////////////////////////////////////////////////////////////////////////////////////////////////
// generic case for floating point -> floating point / integer -> integer conversion
template<typename Ti, typename To, typename = void>
struct ConvertKvCache {
__device__ __host__ ConvertKvCache(float, float) {}
template<int N>
inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<To, N>
{
Array<To, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
vo[i] = (To)vi[i];
}
return vo;
}
};
// generic case for converting to same type, bypass
template<typename T>
struct ConvertKvCache<T, T> {
__device__ __host__ ConvertKvCache(float, float) {}
template<int N>
inline __device__ auto operator()(const Array<T, N>& v) const -> Array<T, N>
{
return v;
}
};
template<typename Ti>
struct ConvertKvCache<Ti, int8_t> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
inline __device__ uint8_t round(float x) const
{
uint32_t y;
asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
return y;
}
template<int N>
inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<int8_t, N>
{
Array<int8_t, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
// convert to unsigned int by offsetting +128
(uint8_t&)vo[i] = round(((float)vi[i] - zero_) / scale_ + 128.f);
}
return vo;
}
};
inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x)
{
union {
Array<float, 4> f32x4;
Array<uint32_t, 4> u32x4;
};
auto& i8s = (const uint32_t&)x;
// 00000000111111112222222233333333
// 01234567012345670123456701234567
// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
// 0????????_______XXXXXXXX________
// (1 + x / 2^15) * 2^(e - 127) -> e - 127 == 15 -> e = 142
// 7 6 5 4
static constexpr uint32_t f32_magic = 0x47000000; // 2^15 = 32768
static constexpr uint32_t m0 = 0x7604;
static constexpr uint32_t m1 = 0x7614;
static constexpr uint32_t m2 = 0x7624;
static constexpr uint32_t m3 = 0x7634;
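    // each prmt.b32 places the i-th int8 byte of `i8s` into bits 8..15 of f32_magic
    // (exponent 142, i.e. 2^15), so the resulting float equals 32768 + stored_uint8;
    // subtracting 32896 (= 32768 + 128) recovers the signed value, which the callers
    // below fold into `zero_` instead of doing it here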
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
if (0) { // fused with dequantization
PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
f32x4[i] -= 32896.f; // 32768 + 128
}
}
return f32x4;
}
template<>
struct ConvertKvCache<int8_t, float> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
{
zero_ = zero_ - 32896.f * scale_;
}
template<int N>
inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<float, N>
{
Array<float, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 4) {
auto& vec = (Array<float, 4>&)vo[i];
vec = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
PRAGMA_UNROLL
for (int j = 0; j < 4; ++j) {
vec[j] = vec[j] * scale_ + zero_;
// vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
}
}
return vo;
}
};
template<>
struct ConvertKvCache<int8_t, half> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
{
zero_ = zero_ - 32896.f * scale_;
}
template<int N>
inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<half, N>
{
Array<half, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 4) {
auto& vec = (Array<half, 4>&)vo[i];
auto tmp = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
PRAGMA_UNROLL
for (int j = 0; j < 4; ++j) {
vec[j] = half(tmp[j] * scale_ + zero_);
// vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
}
}
return vo;
}
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "decoder_multihead_attention_template.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <iostream>
namespace turbomind {
namespace {
template<typename MHAType>
bool Print(size_t dynamic_smem_size)
{
using MapKv = typename MHAType::MapKv;
std::cout << " warps: " << MapKv::kWarpCount << "\n";
std::cout << " shape: (" << MapKv::kC << ", " << MapKv::kS << ")\n";
std::cout << " access: (" << MapKv::kAccessC << ", " << 1 << ")\n";
std::cout << "warpThread: (" << MapKv::kWarpThreadC << ", " << MapKv::kWarpThreadS << ")\n";
std::cout << "warpAccess: (" << MapKv::kWarpAccessC << ", " << MapKv::kWarpAccessS << ")\n";
std::cout << " warpIter: (" << MapKv::kWarpIterC << ", " << MapKv::kWarpIterS << ")\n";
std::cout << " warp: (" << MapKv::kWarpC << ", " << MapKv::kWarpS << ")\n";
std::cout << " iter: (" << MapKv::kIterC << ", " << MapKv::kIterS << ")\n";
std::cout << " footprint: (" << MapKv::kFootprintC << ", " << MapKv::kFootprintS << ")\n";
std::cout << " delta: (" << MapKv::kDeltaC << ", " << MapKv::kDeltaS << ")\n";
std::cout << "dynamic smem size: " << dynamic_smem_size << "\n";
return true;
}
} // namespace
template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
{
auto invoke = [&](auto* type) {
using Attn = std::remove_reference_t<decltype(*type)>;
static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
// [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
dim3 block(Attn::kWarpCount * WARP_SIZE);
dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
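        // grid: x = query-head groups (HeadPerCta heads per CTA), y = batch, z = split-k partition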
// if (params.layer_offset == 0) {
// std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
// }
cudaFuncSetAttribute(
decoder_multihead_attention<Attn>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
decoder_multihead_attention<Attn><<<grid, block, kDynSmemSize, params.stream>>>(params);
if (max_split_k > 1) {
dim3 grid(params.num_heads, params.batch_size);
decoder_multihead_attention_reduce<Attn><<<grid, block, 0, params.stream>>>(params);
}
};
if (params.arch >= 80) {
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>; // 64k
using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
invoke((Type*)0);
}
else {
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>; // 34k
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>; // 34k
using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 1024, 3, true>;
invoke((Type*)0);
}
}
template<typename T>
void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
{
static constexpr int HeadDim = 128;
FT_CHECK(params.size_per_head == HeadDim);
if constexpr (std::is_same_v<T, half>) {
if (params.quant_policy & QuantPolicy::kCacheKVInt8) {
invokeDecoderMultiheadAttention<T, int8_t, HeadDim, 1>(params);
return;
}
int group_size = params.num_heads / params.num_kv_heads;
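        // GQA dispatch: one CTA processes kHeadPerCta query heads sharing a KV head;
        // pick the largest supported value that divides the group size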
if (0) {}
// else if (group_size % 8 == 0) {
// invokeDecoderMultiheadAttention<T, T, HeadDim, 8>(params);
// }
else if (group_size % 4 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 4>(params);
}
else if (group_size % 2 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 2>(params);
}
else {
invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
}
}
}
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<half>& params);
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<float>& params);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "decoder_multihead_attention_params.h"
namespace turbomind {
template<typename T>
void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
}
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cuda_runtime.h>
namespace turbomind {
template<typename T>
struct DecoderMultiHeadAttentionParams {
// token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
T* __restrict__ out;
T* __restrict__ q;
T* __restrict__ k;
T* __restrict__ v;
int stride;
// bias, [qH, D] or [kvH, D]
T* __restrict__ q_bias;
T* __restrict__ k_bias;
T* __restrict__ v_bias;
// sequence-level buffers
const int* __restrict__ per_sample_length;
const bool* __restrict__ finished;
const float* __restrict__ rope_theta;
// kv cache
void** __restrict__ per_sample_k_cache; // [H, S, D]
void** __restrict__ per_sample_v_cache; // [H, S, D]
size_t layer_offset;
/// cache layout M,[N,H,x,D]
/// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
/// 1. [L,sum(S),H,x,D]
void** __restrict__ k_cache_block_ptrs; // X,[H,x,D]
void** __restrict__ v_cache_block_ptrs; // X,[H,x,D]
int* __restrict__ cu_block_cnts; // [B+1]
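    // tokens per KV cache block (the "x" in the layout notes above)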
int kv_cache_block_size;
// batch-level params
int batch_size;
int max_seq_len;
// instance-level params
int num_heads;
int num_kv_heads;
int size_per_head;
float inv_sqrt_dh;
// rotary embedding
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
// bool use_dynamic_ntk;
// log(n) attention
bool use_logn_attn;
int quant_policy;
float kv_quant_params[4];
int max_split_k;
float* partial_O;
float* partial_M;
float* partial_L;
int arch;
cudaStream_t stream;
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "array_ops.h"
#include "iterator.h"
#include "src/turbomind/kernels/gemm_s_f16/common.h"
#include "thread_map.h"
#include <climits>
#include <cmath>
#include <cstdint>
#include <cuda_pipeline_primitives.h>
#include <type_traits>
#include "decoder_multihead_attention_params.h"
namespace turbomind {
template<typename T,
typename Tkv,
int HeadPerCta,
int MaxHeadDim,
int KeyPerIter,
int HeadDim,
int SliceLen,
int Stages,
bool SplitK>
struct DecoderMultiHeadAttentionKernel {
using ParamType = DecoderMultiHeadAttentionParams<T>;
static constexpr int kWarpCount = 4;
static constexpr int kHeadPerCta = HeadPerCta;
static constexpr int kMaxHeadDim = MaxHeadDim;
static constexpr int kKeyPerIter = KeyPerIter;
static constexpr int kHeadDim = HeadDim;
static constexpr int kStages = Stages;
static constexpr bool kSplitK = SplitK;
static constexpr int kSliceLen = SliceLen;
static constexpr int kIterPerSlice = kSliceLen / kKeyPerIter;
static constexpr int kVecKvSize = sizeof(uint4) / sizeof(Tkv);
static constexpr int kThreadPerKey = 8;
using VecKv = Array<T, kVecKvSize>;
using VecKvFloat = Array<float, kVecKvSize>;
static constexpr bool kUseBlockIter = true;
using MapKv = ThreadMapKv<kMaxHeadDim, kKeyPerIter, kVecKvSize, kThreadPerKey, kWarpCount>;
using IterKv = turbomind::Iterator<Tkv, MapKv, SliceLen, kStages, kUseBlockIter>;
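    // dynamic smem layout: [kStages staged KV tiles][kHeadPerCta * kSliceLen floats],
    // where the float buffer is shared by the logits (S) and the softmax probs (P)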
static constexpr size_t GetDynamicSmemSize()
{
size_t smem_kv_cache = IterKv::kSmemByteSize;
// size_t smem_kv_align = 128;
size_t smem_kv_align = 0;
size_t smem_qk = sizeof(float) * kHeadPerCta * kSliceLen;
size_t smem_pr = sizeof(float) * kHeadPerCta * kSliceLen;
return smem_kv_align + smem_kv_cache + std::max(smem_qk, smem_pr);
}
using QkAccumType = float;
using QkComputeType = float;
using PvAccumType = float;
using PvComputeType = float;
struct SharedStorage {
__align__(16) T Q[kHeadPerCta * kMaxHeadDim];
__align__(16) float O[kHeadPerCta * kMaxHeadDim];
float M[kHeadPerCta]; // max{dot(Q, K^T )}
float L[kHeadPerCta]; // sum{exp(s - S_max)}
float red_max[kHeadPerCta * kWarpCount];
float red_sum[kHeadPerCta * kWarpCount];
};
const ParamType& params_;
int head_idx_;
int batch_idx_;
int warp_id_;
int lane_id_;
int kv_head_idx_;
bool is_gqa_leader_;
int step_begin_;
int step_end_;
int timestep_;
Tkv* __restrict__ k_cache_; // [S, D]
Tkv* __restrict__ v_cache_; // [S, D]
const void** __restrict__ k_cache_ptrs_;
const void** __restrict__ v_cache_ptrs_;
Tkv* __restrict__ smem_Kv_;
float* __restrict__ smem_S_;
float* __restrict__ smem_P_;
T* __restrict__ smem_Q_;
float* __restrict__ smem_M_;
float* __restrict__ smem_L_;
float* __restrict__ smem_O_;
float* __restrict__ smem_red_max_;
float* __restrict__ smem_red_sum_;
// avoid redundant type cast for KV8
using KLoadType = std::conditional_t<std::is_same_v<Tkv, int8_t>, float, T>;
using VLoadType = std::conditional_t<std::is_same_v<Tkv, int8_t>, float, T>;
ConvertKvCache<T, Tkv> conv_k_store_;
ConvertKvCache<T, Tkv> conv_v_store_;
ConvertKvCache<Tkv, KLoadType> conv_k_;
ConvertKvCache<Tkv, VLoadType> conv_v_;
__device__ bool thread0()
{
return blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0;
}
__device__ DecoderMultiHeadAttentionKernel(const ParamType& params, SharedStorage& smem, uint8_t* dsmem):
params_(params),
conv_k_store_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
conv_v_store_{params_.kv_quant_params[2], params_.kv_quant_params[3]},
conv_k_{params_.kv_quant_params[0], params_.kv_quant_params[1]},
conv_v_{params_.kv_quant_params[2], params_.kv_quant_params[3]}
{
smem_Kv_ = (Tkv*)dsmem;
smem_S_ = (float*)(smem_Kv_ + IterKv::kSizePerTile * kStages); // [HeadPerCta * kSliceLen]
smem_P_ = smem_S_; // ! reusing only works when S and P have the same dtype
smem_Q_ = smem.Q;
smem_M_ = smem.M;
smem_L_ = smem.L;
smem_O_ = smem.O;
smem_red_max_ = smem.red_max;
smem_red_sum_ = smem.red_sum;
head_idx_ = blockIdx.x * kHeadPerCta;
batch_idx_ = blockIdx.y;
warp_id_ = threadIdx.x / WARP_SIZE;
lane_id_ = threadIdx.x % WARP_SIZE;
const int gqa_group_size = params.num_heads / params.num_kv_heads;
kv_head_idx_ = head_idx_ / gqa_group_size;
is_gqa_leader_ = head_idx_ % gqa_group_size == 0;
timestep_ = params_.per_sample_length[batch_idx_];
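        // split-k: partition [0, timestep_) into max_split_k chunks of whole slices;
        // this CTA (blockIdx.z) processes [step_begin_, step_end_)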
if (kSplitK && params.max_split_k > 1) {
const int slice_count = (timestep_ + kSliceLen - 1) / kSliceLen;
const int slice_per_split = (slice_count + params_.max_split_k - 1) / params_.max_split_k;
step_begin_ = slice_per_split * get_split_k_idx() * kSliceLen;
step_end_ = min(timestep_, step_begin_ + slice_per_split * kSliceLen);
}
else {
step_begin_ = 0;
step_end_ = timestep_;
}
if constexpr (kUseBlockIter) {
k_cache_ptrs_ = params_.k_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
v_cache_ptrs_ = params_.v_cache_block_ptrs + params_.cu_block_cnts[batch_idx_];
}
else {
k_cache_ = (T*)params_.per_sample_k_cache[batch_idx_] + params.layer_offset
+ kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
v_cache_ = (T*)params_.per_sample_v_cache[batch_idx_] + params.layer_offset
+ kv_head_idx_ * params_.max_seq_len * params_.size_per_head;
}
}
__device__ void Prologue()
{
// - Each warp is handling a row of Q
// - K/V are loaded redundantly only for the current step
static_assert(kMaxHeadDim % WARP_SIZE == 0);
static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
using VecQ = Array<T, kVecQSize>;
using VecQFloat = Array<float, kVecQSize>;
using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
static constexpr int kQVecPerThread = MapQ::kIterC;
static constexpr int kQHeadPerThread = MapQ::kIterS; // > 1 when kWarpCount < kHeadPerCta
static_assert(kQVecPerThread == 1);
int2 offset = MapQ::get_offset(warp_id_, lane_id_);
bool is_valid = offset.x < kMaxHeadDim && offset.y < kHeadPerCta;
if (!is_valid) {
return;
}
VecQ frag_Q[kQHeadPerThread];
VecQ frag_K;
VecQ frag_V;
// load qkv
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
int di = offset.x;
int qi = offset.y + s;
Ldg(frag_Q[s], &params_.q[batch_idx_ * params_.stride + (head_idx_ + qi) * kHeadDim + di]);
}
Ldg(frag_K, &params_.k[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
Ldg(frag_V, &params_.v[batch_idx_ * params_.stride + kv_head_idx_ * kHeadDim + offset.x]);
if (params_.q_bias) {
// load biases
VecQ bias_Q[kQHeadPerThread];
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
int di = offset.x;
int qi = offset.y + s;
Ldg(bias_Q[s], &params_.q_bias[(head_idx_ + qi) * kHeadDim + di]);
}
VecQ bias_K;
VecQ bias_V;
Ldg(bias_K, &params_.k_bias[kv_head_idx_ * kHeadDim + offset.x]);
Ldg(bias_V, &params_.v_bias[kv_head_idx_ * kHeadDim + offset.x]);
using namespace ops;
// apply biases
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
frag_Q[s] = frag_Q[s] + bias_Q[s];
}
frag_K = frag_K + bias_K;
frag_V = frag_V + bias_V;
}
// for (int i = 0; i < kVecQSize; ++i) {
// printf("q[%2d][%3d] = %f\n", (int)head_idx_, (int)(offset.x + i), (float)frag_Q[0][i]);
// }
float rotary_embedding_base =
params_.rope_theta ? params_.rope_theta[batch_idx_] : params_.rotary_embedding_base;
// Apply rotary embedding
RotaryEmbedding<kVecQSize> rotary_emb(rotary_embedding_base, params_.rotary_embedding_dim, timestep_, offset);
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
rotary_emb.apply(frag_Q[s]);
}
rotary_emb.apply(frag_K);
if (params_.use_logn_attn) {
LogNScaling logn_scaling(timestep_ + 1, params_.max_position_embeddings);
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
logn_scaling.apply(frag_Q[s]);
}
}
if (kSplitK && step_begin_) { // Split idx > 0
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
int qi = offset.y + s;
if (lane_id_ == 0) {
smem_M_[qi] = -std::numeric_limits<float>::infinity();
smem_L_[qi] = 0.f;
}
Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
Store(&smem_O_[qi * kMaxHeadDim + offset.x], VecQFloat{});
}
return;
}
////////////////////////////////////////////////////////
// Split 0 computes last step and stores to k/v cache
PRAGMA_UNROLL
for (int s = 0; s < kQHeadPerThread; ++s) {
int qi = offset.y + s;
QkAccumType qk = qk_dot<QkAccumType, QkComputeType, WARP_SIZE>(frag_Q[s], frag_K);
if (lane_id_ == 0) {
qk *= params_.inv_sqrt_dh;
smem_M_[qi] = qk;
smem_L_[qi] = 1.f;
// printf("qk[%2d] = %f\n", head_idx_, qk);
}
// write Q and O
Store(&smem_Q_[qi * kMaxHeadDim + offset.x], frag_Q[s]);
Store(&smem_O_[qi * kMaxHeadDim + offset.x], cast<float>(frag_V));
}
auto frag_K_store = conv_k_store_(frag_K);
auto frag_V_store = conv_v_store_(frag_V);
// store
if (warp_id_ == 0 && is_gqa_leader_) {
if constexpr (kUseBlockIter) {
int block_index = timestep_ / params_.kv_cache_block_size;
int block_offset = timestep_ % params_.kv_cache_block_size;
// if (thread0()) {
// printf("%d %d %p %p\n", block_index, block_offset, k_cache_ptrs_, v_cache_ptrs_);
// }
k_cache_ = (Tkv*)k_cache_ptrs_[block_index] + params_.layer_offset
+ kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
v_cache_ = (Tkv*)v_cache_ptrs_[block_index] + params_.layer_offset
+ kv_head_idx_ * params_.kv_cache_block_size * kHeadDim;
Store(&k_cache_[block_offset * kHeadDim + offset.x], frag_K_store);
Store(&v_cache_[block_offset * kHeadDim + offset.x], frag_V_store);
}
else {
Store(&k_cache_[timestep_ * kHeadDim + offset.x], frag_K_store);
Store(&v_cache_[timestep_ * kHeadDim + offset.x], frag_V_store);
}
}
}
__device__ void PrefetchKvCache(IterKv& iter)
{
PRAGMA_UNROLL
for (int stage = 0; stage < kStages - 1; ++stage) {
iter.PrefetchStage();
CpAsyncCommit();
}
}
__device__ void CpAsyncWait()
{
__pipeline_wait_prior(kStages - 2);
}
__device__ void CpAsyncCommit()
{
__pipeline_commit();
}
__device__ void CpAsyncFlush()
{
__pipeline_commit();
__pipeline_wait_prior(0);
}
static constexpr int kKvVecPerThread = MapKv::kIterC;
static constexpr int kKvKeyPerThread = MapKv::kIterS;
struct FragmentQ {
VecKv data[kHeadPerCta][kKvVecPerThread];
};
struct State {
// Double buffering to hide smem/dequant latency
Array<KLoadType, kVecKvSize> frag_K_buf[2][kKvVecPerThread];
Array<VLoadType, kVecKvSize> frag_V_buf[2][kKvVecPerThread];
Array<Tkv, kVecKvSize> frag_Kv_tmp_buf[2][kKvVecPerThread];
};
static constexpr int kPrefetchCount = (IterKv::kIterCount + MapKv::kIterS - 1) / MapKv::kIterS;
__device__ void ComputeSlice(FragmentQ& frag_Q, State& state, const int2& offset, int step, int iter_length)
{
Array<float, kHeadPerCta> frag_M;
PRAGMA_UNROLL
for (int i = 0; i < kHeadPerCta; ++i) {
frag_M[i] = smem_M_[i];
}
IterKv iter_K;
if constexpr (kUseBlockIter) {
iter_K = {k_cache_ptrs_,
params_.kv_cache_block_size,
params_.layer_offset,
kv_head_idx_,
smem_Kv_,
step,
step + iter_length,
warp_id_,
lane_id_};
}
else {
iter_K = {k_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
}
PrefetchKvCache(iter_K);
CpAsyncWait();
iter_K.Load(state.frag_Kv_tmp_buf[0]);
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_K_buf[0][vi] = conv_k_(state.frag_Kv_tmp_buf[0][vi]);
}
iter_K.PrefetchBatch(0, kPrefetchCount);
if (kKvKeyPerThread == 1) {
CpAsyncCommit();
CpAsyncWait();
iter_K.AdvancePrefetchStage();
iter_K.AdvanceComputeStage();
}
///////////////////////////////////////////////////////////////////////////////////////////
/// Compute QK(Q, S) = Q(Q, D) * K^T(D, S)
PRAGMA_NO_UNROLL
for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
PRAGMA_UNROLL
for (int si = 0; si < kKvKeyPerThread; ++si) {
const int next = (si + 1) % 2;
// smem -> rmem for next iter
iter_K.Load(state.frag_Kv_tmp_buf[next]);
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_K_buf[next][vi] = conv_k_(state.frag_Kv_tmp_buf[next][vi]);
}
// current iter's K fragment
auto& frag_K = state.frag_K_buf[si % 2];
const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
auto qk = qk_dot<QkAccumType, QkComputeType, kThreadPerKey>(frag_Q.data[qi], frag_K);
// if (ti == 16) {
// for (int vi = 0; vi < kKvVecPerThread; ++vi) {
// for (int i = 0; i < kVecKvSize; ++i) {
// printf("frag_Q = %f, frag_K[%d] = %f\n",
// (float)frag_Q.data[qi][vi][i],
// offset.x + vi * kVecKvSize + i,
// (float)frag_K[vi][i]);
// }
// }
// }
qk *= params_.inv_sqrt_dh;
if (step + local_offset < timestep_) {
// group leader writes to smem
if (threadIdx.x % kThreadPerKey == 0) {
// printf("qk_%d = %f\n", step + local_offset, (float)qk);
smem_S_[kSliceLen * qi + local_offset] = qk;
// local max
frag_M[qi] = fmaxf(frag_M[qi], qk);
}
}
}
iter_K.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
CpAsyncCommit();
CpAsyncWait();
iter_K.AdvancePrefetchStage();
iter_K.AdvanceComputeStage();
}
}
// handle special case
if (kKvKeyPerThread == 1) {
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_K_buf[0][vi] = state.frag_K_buf[1][vi];
}
}
}
CpAsyncFlush();
__syncthreads();
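        // online-softmax update: the running max M and sum L from previous slices are
        // rescaled by exp(M_old - M_new) so partial results stay consistent across slices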
Array<float, kHeadPerCta> exp_M_diff;
PRAGMA_UNROLL
for (int i = 0; i < kHeadPerCta; ++i) {
exp_M_diff[i] = smem_M_[i];
}
/// block synchronization
frag_M = qk_max<MapKv>(frag_M, smem_red_max_, warp_id_, lane_id_);
// wait while smem_red_ is being used.
// __syncthreads();
PRAGMA_UNROLL
for (int i = 0; i < kHeadPerCta; ++i) {
// if (thread0()) {
// printf("%f %f %f\n", (float)exp_M_diff[i], (float)frag_M[i], (float)__expf(exp_M_diff[i] -
// frag_M[i]));
// }
// exp(m1 - m2)
exp_M_diff[i] = __expf(exp_M_diff[i] - frag_M[i]);
if (threadIdx.x == 0) {
smem_M_[i] = frag_M[i];
}
}
// if (threadIdx.x == 0 && step + iter_length == timestep_) {
// printf("frag_M[%2d] = %f\n", head_idx_, (float)frag_M[0]);
// }
// __syncthreads(); // DEBUG
/////////////////////////////////////////////////////////////////////////////////////////
/// Compute softmax P(Q, S)
Array<float, kHeadPerCta> frag_L{};
for (int ti = threadIdx.x; ti < iter_length; ti += kWarpCount * WARP_SIZE) {
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
int idx = qi * kSliceLen + ti;
float qk = smem_S_[idx];
float pr = expf(qk - frag_M[qi]);
// printf("smem_P[%d] = %f\n", ti, pr);
smem_P_[idx] = pr;
frag_L[qi] += pr;
}
}
/// block synchronization
frag_L = blockSum<kWarpCount>(frag_L, smem_red_sum_, warp_id_, lane_id_);
for (int qi = 0; qi < kHeadPerCta; ++qi) {
// exp(m1 - m2) * l1
frag_L[qi] += exp_M_diff[qi] * smem_L_[qi];
}
__syncthreads();
for (int qi = 0; qi < kHeadPerCta; ++qi) {
if (threadIdx.x == 0) {
smem_L_[qi] = frag_L[qi];
}
}
if (threadIdx.x == 0 && step == timestep_ - kSliceLen) {
// printf("frag_L'[%d] = %f\n", head_idx_, (float)frag_L[0]);
}
/////////////////////////////////////////////////////////////////////////////////////////
/// Compute O[H,D] = P[H,S] * V[S,D]
VecKvFloat frag_O[kHeadPerCta][kKvVecPerThread]{}; // value initialize
// float frag_Pr_buf[2][kHeadPerCta];
// ti = step + offset.y;
// int ti = step + offset.y;
// PRAGMA_UNROLL
// for (int qi = 0; qi < kHeadPerCta; ++qi) {
// // prefetch Pr for first warp iter
// frag_Pr_buf[0][qi] = smem_P_[qi * kSliceLen + ti];
// }
IterKv iter_V;
if constexpr (kUseBlockIter) {
iter_V = {v_cache_ptrs_,
params_.kv_cache_block_size,
params_.layer_offset,
kv_head_idx_,
smem_Kv_,
step,
step + iter_length,
warp_id_,
lane_id_};
}
else {
iter_V = {v_cache_, smem_Kv_, step, step + iter_length, warp_id_, lane_id_};
}
PrefetchKvCache(iter_V);
CpAsyncWait();
iter_V.Load(state.frag_Kv_tmp_buf[0]);
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_V_buf[0][vi] = conv_v_(state.frag_Kv_tmp_buf[0][vi]);
}
iter_V.PrefetchBatch(0, kPrefetchCount);
if (kKvKeyPerThread == 1) {
CpAsyncCommit();
CpAsyncWait();
iter_V.AdvancePrefetchStage();
iter_V.AdvanceComputeStage();
}
PRAGMA_NO_UNROLL
for (int _it = 0; _it < iter_length; _it += kKeyPerIter) {
PRAGMA_UNROLL
for (int si = 0; si < kKvKeyPerThread; ++si) {
const int next = (si + 1) % 2;
// Load value cache for next warp iter
iter_V.Load(state.frag_Kv_tmp_buf[next]);
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_V_buf[next][vi] = conv_v_(state.frag_Kv_tmp_buf[next][vi]);
}
// Load Pr for next warp iter
// PRAGMA_UNROLL
// for (int qi = 0; qi < kHeadPerCta; ++qi) {
// frag_Pr_buf[(si + 1) % 2][qi] = smem_P_[qi * kSliceLen + (ti + MapKv::kWarpAccessS)];
// }
auto& frag_V = state.frag_V_buf[si % 2];
// auto& frag_P = frag_Pr_buf[si % 2];
const int local_offset = offset.y + _it + si * MapKv::kWarpAccessS;
float frag_P[kHeadPerCta];
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
frag_P[qi] = smem_P_[qi * kSliceLen + local_offset];
}
if (step + local_offset < timestep_) {
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
fma_pv<PvComputeType>(frag_P[qi], frag_V, frag_O[qi]);
}
// for (int i = 0; i < kKvVecPerThread; ++i) {
// for (int j = 0; j < kVecKvSize; ++j) {
// printf("frag_V %f\n", (float)frag_V[i][j]);
// }
// }
// if (threadIdx.x % MapKv::kWarpThreadC == 0) {
// printf("frag_P[%d] %f\n", ti, frag_P[0]);
// }
}
iter_V.PrefetchBatch((si + 1) % kKvKeyPerThread, kPrefetchCount);
if (kKvKeyPerThread == 1 || si == kKvKeyPerThread - 2) {
CpAsyncCommit();
CpAsyncWait();
iter_V.AdvancePrefetchStage();
iter_V.AdvanceComputeStage();
}
}
// handle special case
if (kKvKeyPerThread == 1) {
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
state.frag_V_buf[0][vi] = state.frag_V_buf[1][vi];
}
// PRAGMA_UNROLL
// for (int qi = 0; qi < kHeadPerCta; ++qi) {
// frag_Pr_buf[0][qi] = frag_Pr_buf[1][qi];
// }
}
}
/// warp reduce over S dim
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
PRAGMA_UNROLL
for (int i = 0; i < kVecKvSize; ++i) {
// reduce over warp thread S
PRAGMA_UNROLL
for (int mask = WARP_SIZE / 2; mask >= MapKv::kWarpThreadC; mask /= 2) {
frag_O[qi][vi][i] += __shfl_xor_sync(uint32_t(-1), frag_O[qi][vi][i], mask);
}
}
}
}
// __syncthreads();
PRAGMA_UNROLL
for (int gi = 0; gi < MapKv::kS; gi += MapKv::kFootprintS) {
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
PRAGMA_UNROLL
for (int vi = 0; vi < kKvVecPerThread; ++vi) {
if (offset.y == gi) {
// bank conflict
auto& smem_O = (VecKvFloat&)smem_O_[qi * kMaxHeadDim + offset.x + vi * MapKv::kDeltaC];
using namespace ops;
auto tmp_O = smem_O;
if (offset.y == 0) {
tmp_O = tmp_O * exp_M_diff[qi];
}
// bank conflict
smem_O = tmp_O + frag_O[qi][vi];
}
}
}
__syncthreads();
}
CpAsyncFlush();
}
__device__ void LoopKv()
{
const int2 offset = MapKv::get_offset(warp_id_, lane_id_);
///////////////////////////////////////////////////////////////////////////////////////////
/// Load Q from shared memory.
/// NOTE: There will be bank-conflict when sizeof(VecKv) > 16 (e.g. KV is quantized)
FragmentQ frag_Q;
PRAGMA_UNROLL
for (int qi = 0; qi < kHeadPerCta; ++qi) {
PRAGMA_UNROLL
for (int c = 0; c < kKvVecPerThread; ++c) {
const int di = offset.x + MapKv::kDeltaC * c;
frag_Q.data[qi][c] = (VecKv&)smem_Q_[qi * kMaxHeadDim + di];
}
}
State state;
PRAGMA_NO_UNROLL
for (int step = step_begin_; step < step_end_; step += kSliceLen) {
int iter_count = min(step_end_ - step, kSliceLen);
ComputeSlice(frag_Q, state, offset, step, iter_count);
}
}
__device__ void Run()
{
if constexpr (0) {
for (int i = threadIdx.x; i < kStages * IterKv::kSizePerTile; i += blockDim.x) {
smem_Kv_[i] = T(0);
}
__syncthreads();
}
// early exit if this split is out of bounds
if (kSplitK && step_begin_ >= step_end_) {
return;
}
// early exit if finished flag is set
if (params_.finished[batch_idx_]) {
return;
}
// Compute attention for current step
Prologue();
__syncthreads();
// Iterate over K/V
LoopKv();
__syncthreads();
// Normalize outputs & write to device memory
Epilogue();
}
__device__ void Epilogue()
{
static constexpr int kVecQSize = kMaxHeadDim / WARP_SIZE;
using VecQFloat = Array<float, kVecQSize>;
using MapQ = ThreadMapQ<kMaxHeadDim, kHeadPerCta, kVecQSize, kWarpCount>;
static constexpr int kQkvHeadPerThread = MapQ::kIterS;
int2 offset = MapQ::get_offset(warp_id_, lane_id_);
if (offset.x >= kMaxHeadDim || offset.y >= kHeadPerCta) {
return;
}
using namespace ops;
if (!kSplitK || (step_begin_ == 0 && step_end_ == timestep_)) { // non-split-k
PRAGMA_UNROLL
for (int s = 0; s < kQkvHeadPerThread; ++s) {
const int di = offset.x;
const int qi = offset.y + s;
const float scale = __fdividef(1.f, smem_L_[qi] + 1e-8f);
const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di] * scale;
Store(&params_.out[batch_idx_ * params_.num_heads * kHeadDim + (head_idx_ + qi) * kHeadDim + di],
cast<T>(frag_O));
}
}
else {
PRAGMA_UNROLL
for (int s = 0; s < kQkvHeadPerThread; ++s) { // split-k
const int di = offset.x;
const int qi = offset.y + s;
const VecQFloat frag_O = (VecQFloat&)smem_O_[qi * kMaxHeadDim + di];
// [B, H, k, D]
const int index = batch_idx_ * params_.num_heads * params_.max_split_k
+ (head_idx_ + qi) * params_.max_split_k + get_split_k_idx();
Store(&params_.partial_O[index * kHeadDim + di], cast<float>(frag_O));
if (di == 0) {
params_.partial_M[index] = smem_M_[qi];
params_.partial_L[index] = smem_L_[qi];
}
}
}
}
static __device__ void Reduce(const ParamType& params)
{
const int batch_idx = get_batch_idx();
const int head_idx = get_head_idx();
const int timestep = params.per_sample_length[batch_idx];
const int max_split_k = params.max_split_k;
const int slice_count = get_slice_count(timestep);
const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k;
const int split_k = (slice_count + slice_per_split - 1) / slice_per_split;
if (split_k == 1) {
return;
}
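        // merge per-split partials (flash-attention style):
        //   M = max_k M_k,  L = sum_k exp(M_k - M) * L_k,  out = sum_k exp(M_k - M) * O_k / L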
// [B, H, k, D]
const int index = batch_idx * params.num_heads * max_split_k + head_idx * max_split_k + threadIdx.x;
__shared__ float smem_global_M;
__shared__ float smem_global_L;
__shared__ __align__(16) float smem_expdiff_M[WARP_SIZE];
__shared__ __align__(16) float smem_scale_O[WARP_SIZE];
{
float global_M = threadIdx.x < split_k ? params.partial_M[index] : -std::numeric_limits<float>::infinity();
PRAGMA_UNROLL
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_M = fmaxf(global_M, __shfl_xor_sync((uint32_t)-1, global_M, mask));
}
if (threadIdx.x == 0) {
smem_global_M = global_M;
}
}
__syncthreads();
{
float global_L = threadIdx.x < split_k ? params.partial_L[index] : 0.f;
if (threadIdx.x < split_k) {
auto expdiff_M = expf(params.partial_M[index] - smem_global_M);
global_L *= expdiff_M;
smem_expdiff_M[threadIdx.x] = expdiff_M;
}
PRAGMA_UNROLL
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_L += __shfl_xor_sync((uint32_t)-1, global_L, mask);
}
if (threadIdx.x == 0) {
smem_global_L = global_L;
}
}
__syncthreads();
if (threadIdx.x < split_k) {
smem_scale_O[threadIdx.x] = smem_expdiff_M[threadIdx.x] / (smem_global_L + 1e-8f);
}
__syncthreads();
int idx = (batch_idx * params.num_heads * max_split_k + head_idx * max_split_k) * kHeadDim + threadIdx.x;
float accum_O{};
const bool is_valid = threadIdx.x < kHeadDim;
for (int k = 0; k < split_k; ++k) {
if (is_valid) {
accum_O += smem_scale_O[k] * params.partial_O[idx];
}
idx += kHeadDim;
}
if (is_valid) {
params.out[batch_idx * params.num_heads * kHeadDim + head_idx * kHeadDim + threadIdx.x] = (T)accum_O;
}
}
static __device__ int get_slice_count(int timestep)
{
return (timestep + kSliceLen - 1) / kSliceLen;
}
static __device__ int get_head_idx()
{
return blockIdx.x;
}
static __device__ int get_batch_idx()
{
return blockIdx.y;
}
static __device__ int get_split_k_idx()
{
return blockIdx.z;
}
};
extern __shared__ uint8_t dynamic_smem[];
template<typename MHAType, typename ParamType = typename MHAType::ParamType>
__global__ void decoder_multihead_attention(ParamType params)
{
__shared__ typename MHAType::SharedStorage shared_storage;
uint8_t* smem_ptr = dynamic_smem;
MHAType{params, shared_storage, smem_ptr}.Run();
}
template<typename MHAType, typename ParamType = typename MHAType::ParamType>
__global__ void decoder_multihead_attention_reduce(ParamType params)
{
MHAType::Reduce(params);
}
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "../gemm_s_f16/common.h"
#include "array_ops.h"
namespace turbomind {
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
struct BlockIterator {
const void** ptrs_;
const void* prefetch_;
BlockIterator() = default;
__device__ BlockIterator(const void** block_ptrs): ptrs_{block_ptrs}
{
// prefetch first ptr
prefetch_ = *ptrs_++;
}
__device__ const void* Next()
{
// return prefetched ptr
const void* ret = prefetch_;
// prefetch next ptr
prefetch_ = *ptrs_++;
return ret;
}
};
template<typename T, typename ThreadMap, int BlockLen, int Stages, bool kUseBlockIter>
struct Iterator {
using ElementType = T;
using AccessType = Array<T, ThreadMap::kAccessC>;
static constexpr int kElementSize = sizeof(ElementType);
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kSizePerTile = ThreadMap::kS * ThreadMap::kC;
static constexpr int kSmemByteSize = kElementSize * Stages * kSizePerTile;
BlockIterator block_iterator_;
static constexpr int kIterCount = ThreadMap::kIterS * ThreadMap::kIterC;
static constexpr int kStepC = ThreadMap::kDeltaC;
static constexpr int kStepS = ThreadMap::kDeltaS * ThreadMap::kC - ThreadMap::kIterC * kStepC;
static constexpr int kStepK =
ThreadMap::kS * ThreadMap::kC - ThreadMap::kIterS * ThreadMap::kDeltaS * ThreadMap::kC;
// (C, S, K) = (64, 384, 1536)
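    // kStepC: advance to the next accessed column within the current rows
    // kStepS: advance to the next accessed row group, rewinding the column offset
    // kStepK: advance to the next stage/tile, rewinding the row offset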
// initial offset, used to reset src_offset when switching to a new block
int init_offset_;
int src_offset_;
int dst_offset_;
int iter_c_;
int iter_b_;
int seq_len_;
int offset_s_;
bool is_valid_s_;
int block_size_;
int block_k_;
int layer_offset_;
int head_idx_;
const T* __restrict__ src_;
T* __restrict__ smem_;
int smem_read_offset_;
struct __align__(sizeof(AccessType)) SharedStorage
{
T smem_[Stages][kSizePerTile];
};
Iterator() = default;
__device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
{
src_ = src;
smem_ = smem;
int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
src_offset_ = init_offset_ + step * ThreadMap::kC;
dst_offset_ = init_offset_;
smem_read_offset_ = init_offset_;
iter_c_ = 0;
iter_b_ = 0;
seq_len_ = seq_len;
offset_s_ = init_offset_cs.y + step;
is_valid_s_ = offset_s_ < seq_len;
}
__device__ Iterator(const void** block_ptrs,
int block_size,
int layer_offset,
int head_idx,
T* smem,
int step,
int seqlen,
int warp_id,
int lane_id)
{
// src_ = src;
int block_index = step / block_size;
block_size_ = block_size;
block_k_ = (block_index + 1) * block_size - step; // offset to next block
layer_offset_ = layer_offset;
head_idx_ = head_idx;
block_iterator_ = BlockIterator(block_ptrs + block_index);
src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
smem_ = smem;
int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
src_offset_ = init_offset_ + (step - block_index * block_size) * ThreadMap::kC;
dst_offset_ = init_offset_;
smem_read_offset_ = init_offset_;
iter_c_ = 0;
iter_b_ = 0;
seq_len_ = seqlen;
offset_s_ = init_offset_cs.y + step;
is_valid_s_ = offset_s_ < seqlen;
}
__device__ void PrefetchStage()
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
Prefetch(is_valid_s_);
++(*this);
}
AdvancePrefetchStage();
}
__device__ void PrefetchBatch(int batch_idx, int batch_size)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
Prefetch(is_valid_s_);
++(*this);
}
}
}
__device__ Iterator& operator++()
{
src_offset_ += kStepC;
dst_offset_ += kStepC;
++iter_c_;
if (iter_c_ < ThreadMap::kIterC) {
return *this;
}
iter_c_ = 0;
src_offset_ += kStepS;
dst_offset_ += kStepS;
offset_s_ += ThreadMap::kDeltaS;
is_valid_s_ = offset_s_ < seq_len_;
return *this;
}
__device__ void AdvancePrefetchStage()
{
src_offset_ += kStepK;
dst_offset_ += kStepK;
offset_s_ += ThreadMap::kS - ThreadMap::kIterS * ThreadMap::kDeltaS;
is_valid_s_ = offset_s_ < seq_len_;
if constexpr (kUseBlockIter) {
if (is_valid_s_) {
block_k_ -= ThreadMap::kS;
if (block_k_ == 0) {
src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
block_k_ = block_size_;
src_offset_ = init_offset_;
}
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// printf("%d %d %d\n", offset_s_, src_offset_ / ThreadMap::kC, block_k_);
// }
}
// if (init_offset_ / ThreadMap::kC == 0) {
// int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
// int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
// int c = dst_offset_ % ThreadMap::kC;
// printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d\n",
// threadIdx.x,
// k,
// s,
// c,
// offset_s_,
// (int)is_valid_s_,
// init_offset_ / ThreadMap::kC);
// }
// if (threadIdx.x == 0 && blockIdx.x == 0) {
// printf("next stage %d\n", offset_s_);
// }
if (dst_offset_ >= Stages * kSizePerTile) {
dst_offset_ -= Stages * kSizePerTile;
}
// if constexpr (Chained) {
// bool is_last_stage = *signal_iterator_;
// ++signal_iterator_;
// if (is_last_stage) {
// AdvancePrefetchSlice();
// }
// }
}
#if 0
__device__ void AdvancePrefetchSlice()
{
src_ = (const T*)block_iterator_.Next();
src_offset_ = init_offset_;
++iter_b_;
offset_s_ = iter_b_ / 2 * BlockLen + init_offset_ / ThreadMap::kC;
is_valid_s_ = offset_s_ < seq_len_;
}
#endif
static __device__ void CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask)
{
const int smem_int_ptr = cast_smem_ptr_to_uint(dst);
constexpr int cp_size = sizeof(AccessType);
#if TURBOMIND_ARCH_SM80
// clang-format off
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
// clang-format on
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
{
if (mask) {
Ldg(*(AccessType*)dst, src);
}
}
__device__ void Prefetch(bool mask)
{
if constexpr (TURBOMIND_ARCH_SM80) {
CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
}
else {
Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
}
}
__device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
{
// if (init_offset_ / ThreadMap::kC == 0) {
// int k = smem_read_offset_ / (ThreadMap::kS * ThreadMap::kC);
// int s = smem_read_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
// int c = smem_read_offset_ % ThreadMap::kC;
// printf("tid=%d, k=%d, s=%d, c=%d, init_s=%d\n", threadIdx.x, k, s, c, init_offset_ / ThreadMap::kC);
// }
for (int vi = 0; vi < ThreadMap::kIterC; ++vi) {
// int offset = smem_read_offset_ + vi * ThreadMap::kDeltaC;
// if (offset >= Stages * kSizePerTile || offset % sizeof(AccessType)) {
// int c = offset % ThreadMap::kC;
// int s = offset / ThreadMap::kC;
// printf("%d %d %d\n", c, s, offset);
// }
Lds(frag[vi], smem_ + smem_read_offset_ + vi * ThreadMap::kDeltaC);
}
smem_read_offset_ += ThreadMap::kDeltaS * ThreadMap::kC;
}
__device__ void AdvanceComputeStage()
{
smem_read_offset_ += kStepK;
if (smem_read_offset_ >= Stages * kSizePerTile) {
smem_read_offset_ -= Stages * kSizePerTile;
}
}
};
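// A rough usage sketch of the iterator in a Stages-deep software pipeline (illustrative only;
// the actual attention kernel may interleave these calls differently):
//
//   Iterator it{/* block ptrs, offsets, smem, thread mapping, ... */};
//   for (int s = 0; s < Stages - 1; ++s) {
//       it.PrefetchStage();                 // fill all but one stage of shared memory
//   }
//   for (int s = 0; s < ThreadMap::kIterS; ++s) {
//       AccessType frag[ThreadMap::kIterC];
//       it.Load(frag);                      // read one S-slice from shared memory
//       it.PrefetchBatch(s, kBatchSize);    // overlap the next global->shared copies
//       // ... multiply-accumulate on `frag` ...
//   }
//   it.AdvanceComputeStage();               // wrap the read offset at the stage boundary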
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "../gemm_s_f16/common.h"
#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include <cuda_fp16.h>
#include <type_traits>
namespace turbomind {
// [S/x, H, x, D] <-> [S/y, H, y, D]
template<typename Tin,
typename Tout,
typename SrcBlockLen,
typename DstBlockLen,
typename HeadDim,
typename Transform = ConvertKvCache<Tin, Tout>>
__inline__ __device__ void ConvertBlockSize(const Tin** __restrict__ src_block_ptrs,
Tout** __restrict__ dst_block_ptrs,
const int* __restrict__ src_cu_block_cnts,
const int* __restrict__ dst_cu_block_cnts,
const int* __restrict__ seq_lens,
int src_offset,
int dst_offset,
SrcBlockLen src_block_len,
DstBlockLen dst_block_len,
HeadDim head_dim,
Transform transform = {1.f, 0.f})
{
constexpr int kVecSize = sizeof(uint4) / std::max(sizeof(Tin), sizeof(Tout));
const int hi = blockIdx.y;
const int bi = blockIdx.z;
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
/// TODO: use cutlass fast div/mod
const int di = idx * kVecSize % head_dim;
const int si = idx * kVecSize / head_dim;
if (si >= seq_lens[bi]) {
return;
}
// compute indices into src
int src_block_index = si / src_block_len + src_cu_block_cnts[bi];
int src_block_offset = src_offset + hi * src_block_len * head_dim + si % src_block_len * head_dim + di;
// compute indices into dst
int dst_block_index = si / dst_block_len + dst_cu_block_cnts[bi];
int dst_block_offset = dst_offset + hi * dst_block_len * head_dim + si % dst_block_len * head_dim + di;
// printf("%d %d\n", src_block_index, dst_block_index);
const Tin* __restrict__ src_block = src_block_ptrs[src_block_index];
Tout* __restrict__ dst_block = dst_block_ptrs[dst_block_index];
// uint4 data = __ldg(reinterpret_cast<const uint4*>(src_block + src_block_offset));
Array<Tin, kVecSize> src_vec;
Ldg(src_vec, src_block + src_block_offset);
Array<Tout, kVecSize> dst_vec = transform(src_vec);
Store(dst_block + dst_block_offset, dst_vec);
// *reinterpret_cast<uint4*>(dst_block + dst_block_offset) = data;
}
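// Worked example (hypothetical values) of the index mapping above, assuming Tin = Tout = half
// and head_dim = 128: kVecSize = 16 / 2 = 8, so each thread moves one 16-byte vector.
// For idx = 2100 the vector starts at element 2100 * 8 = 16800, giving di = 16800 % 128 = 32
// and si = 16800 / 128 = 131. With src_block_len = 1024 (one linear chunk per sequence) and
// dst_block_len = 128 (paged KV blocks), the source lands in block si / 1024 = 0 at row
// si % 1024 = 131, while the destination lands in block si / 128 = 1 at row si % 128 = 3;
// the cumulative block counts then offset both block indices to the right sequence. The
// transform converts the element type on the fly (e.g. int8 -> half dequantization).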
template<typename T>
__global__ void LinearToBlocksKernel(const T* src,
T** dst_block_ptrs,
const int* dst_cu_block_cnts,
const int* seq_lens,
int dst_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size)
{
extern __shared__ void* smem[];
const T** src_block_ptrs = (const T**)smem;
int* src_cu_block_cnts = (int*)(src_block_ptrs + batch_size);
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
src_cu_block_cnts[i] = i;
src_block_ptrs[i] = src + blockIdx.z * head_num * src_block_len * head_dim;
}
__syncthreads();
ConvertBlockSize(src_block_ptrs,
dst_block_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
0,
dst_offset,
src_block_len,
dst_block_len,
head_dim);
}
template<typename T>
void ConvertLinearToBlocks(const T* src,
T** dst_block_ptrs,
const int* dst_cu_block_cnts,
const int* seq_lens,
int dst_offset,
int src_max_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st)
{
constexpr int kVecSize = sizeof(uint4) / sizeof(T);
constexpr int threads = 128;
const dim3 blocks((src_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
auto fn = [&](auto head_dim) {
LinearToBlocksKernel<<<blocks, threads, smem_sz, st>>>(src,
dst_block_ptrs,
dst_cu_block_cnts,
seq_lens,
dst_offset,
src_max_len,
dst_block_len,
head_num,
head_dim,
batch_size);
};
switch (head_dim) {
case 128:
fn(std::integral_constant<int, 128>{});
break;
default:
fn(head_dim);
}
}
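// The switch above promotes the common head_dim == 128 case to a compile-time constant
// (std::integral_constant) so the kernel's index arithmetic can be specialized, while other
// head sizes take the generic runtime-int path. A hypothetical extra specialization, e.g.
// for 64-dim heads, would follow the same pattern:
//
//   case 64:
//       fn(std::integral_constant<int, 64>{});
//       break;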
template void ConvertLinearToBlocks(const half* src,
half** dst_block_ptrs,
const int* dst_cu_block_cnts,
const int* seq_lens,
int dst_offset,
int src_seq_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
template<typename T, typename HeadDim>
__global__ void BlocksToLinearKernel(const T** src_block_ptrs,
T* dst,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
HeadDim head_dim,
int batch_size)
{
extern __shared__ void* smem[];
T** dst_block_ptrs = (T**)smem;
int* dst_cu_block_cnts = (int*)(dst_block_ptrs + batch_size);
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
dst_cu_block_cnts[i] = i;
dst_block_ptrs[i] = dst + blockIdx.z * head_num * dst_block_len * head_dim;
}
__syncthreads();
ConvertBlockSize(src_block_ptrs,
dst_block_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
src_offset,
0,
src_block_len,
dst_block_len,
head_dim);
}
template<typename T>
void ConvertBlocksToLinear(const T** src_block_ptrs,
T* dst,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_max_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st)
{
constexpr int kVecSize = sizeof(uint4) / sizeof(T);
constexpr int threads = 256;
const dim3 blocks((dst_max_len * head_dim / kVecSize + threads - 1) / threads, head_num, batch_size);
const auto smem_sz = (sizeof(void*) + sizeof(int)) * batch_size;
auto fn = [&](auto head_dim) {
BlocksToLinearKernel<<<blocks, threads, smem_sz, st>>>(src_block_ptrs,
dst,
src_cu_block_cnts,
seq_lens,
src_offset,
src_block_len,
dst_max_len,
head_num,
head_dim,
batch_size);
};
switch (head_dim) {
case 128:
fn(std::integral_constant<int, 128>{});
break;
default:
fn(head_dim);
}
}
template void ConvertBlocksToLinear(const half** src_block_ptrs,
half* dst,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_max_seq_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
template<typename T, typename SrcBlockLen, typename DstBlockLen, typename HeadDim>
__global__ void KvCacheBlocksToLinearKernel(const T** src_k_block_ptrs,
const T** src_v_block_ptrs,
T** dst_k_ptrs,
T** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
SrcBlockLen src_block_len,
DstBlockLen dst_block_len,
int head_num,
HeadDim head_dim,
int batch_size)
{
extern __shared__ int dst_cu_block_cnts[];
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
dst_cu_block_cnts[i] = i;
}
__syncthreads();
ConvertBlockSize(src_k_block_ptrs,
dst_k_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
src_offset,
0,
src_block_len,
dst_block_len,
head_dim);
ConvertBlockSize(src_v_block_ptrs,
dst_v_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
src_offset,
0,
src_block_len,
dst_block_len,
head_dim);
}
void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
void** dst_k_ptrs,
void** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int elem_bits,
cudaStream_t st)
{
auto fn = [&](auto value) {
using T = decltype(value);
constexpr int kVecSize = sizeof(uint4) / sizeof(T);
constexpr int kThreads = 256;
const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
const auto smem_sz = sizeof(int) * batch_size;
KvCacheBlocksToLinearKernel<<<blocks, kThreads, smem_sz, st>>>((const T**)src_k_block_ptrs,
(const T**)src_v_block_ptrs,
(T**)dst_k_ptrs,
(T**)dst_v_ptrs,
src_cu_block_cnts,
seq_lens,
src_offset,
src_block_len,
dst_block_len,
head_num,
head_dim,
batch_size);
};
switch (elem_bits) {
case 8:
fn(uint8_t{});
break;
case 16:
fn(uint16_t{});
break;
case 32:
fn(uint32_t{});
break;
default:
fprintf(stderr, "unsupported elem bits: %d\n", elem_bits);
}
}
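// This untyped variant only relocates bytes, so cache elements are dispatched purely by
// width: 8/16/32-bit payloads are handled as uint8_t/uint16_t/uint32_t. A half-precision
// KV cache, for instance, would be gathered like this (sketch, assuming 16-bit elements and
// a paged block length of 128):
//
//   ConvertKvCacheBlocksToLinear(src_k_block_ptrs, src_v_block_ptrs,
//                                dst_k_ptrs, dst_v_ptrs,
//                                src_cu_block_cnts, seq_lens,
//                                /*src_offset=*/0, /*src_block_len=*/128,
//                                /*dst_block_len=*/max_seq_len,
//                                head_num, /*head_dim=*/128, batch_size,
//                                /*elem_bits=*/16, stream);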
template<typename Tin,
typename Tout,
typename SrcBlockLen,
typename DstBlockLen,
typename HeadDim,
typename TransformK,
typename TransformV>
__global__ void KvCacheBlocksToLinearKernel2(const Tin** src_k_block_ptrs,
const Tin** src_v_block_ptrs,
Tout** dst_k_ptrs,
Tout** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
SrcBlockLen src_block_len,
DstBlockLen dst_block_len,
int head_num,
HeadDim head_dim,
int batch_size,
TransformK transform_k,
TransformV transform_v)
{
extern __shared__ int dst_cu_block_cnts[];
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
dst_cu_block_cnts[i] = i;
}
__syncthreads();
ConvertBlockSize(src_k_block_ptrs,
dst_k_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
src_offset,
0,
src_block_len,
dst_block_len,
head_dim,
transform_k);
ConvertBlockSize(src_v_block_ptrs,
dst_v_ptrs,
src_cu_block_cnts,
dst_cu_block_cnts,
seq_lens,
src_offset,
0,
src_block_len,
dst_block_len,
head_dim,
transform_v);
}
template<typename T>
void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
T** dst_k_ptrs,
T** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st)
{
auto fn = [&](auto tin) {
using Tin = decltype(tin);
constexpr int kVecSize = sizeof(uint4) / sizeof(T);
constexpr int kThreads = 256;
const dim3 blocks((dst_block_len * head_dim / kVecSize + kThreads - 1) / kThreads, head_num, batch_size);
const auto smem_sz = sizeof(int) * batch_size;
KvCacheBlocksToLinearKernel2<<<blocks, kThreads, smem_sz, st>>>(
(const Tin**)src_k_block_ptrs,
(const Tin**)src_v_block_ptrs,
(T**)dst_k_ptrs,
(T**)dst_v_ptrs,
src_cu_block_cnts,
seq_lens,
src_offset,
src_block_len,
dst_block_len,
head_num,
head_dim,
batch_size,
ConvertKvCache<Tin, T>{kv_params[0], kv_params[1]},
ConvertKvCache<Tin, T>{kv_params[2], kv_params[3]});
};
(quant_policy & QuantPolicy::kCacheKVInt8) ? fn(int8_t{}) : fn(T{});
}
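// Quantization-aware variant: when QuantPolicy::kCacheKVInt8 is set, the cache blocks hold
// int8 data and ConvertKvCache<int8_t, T> dequantizes each element on the fly; kv_params is
// read as four floats -- presumably {k_scale, k_zero, v_scale, v_zero} -- feeding the K and V
// transforms respectively. Without the flag, the blocks are read as T and passed through the
// same scale/shift transform.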
template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
float** dst_k_ptrs,
float** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st);
template void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
half** dst_k_ptrs,
half** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cuda_runtime.h>
namespace turbomind {
template<typename T>
void ConvertLinearToBlocks(const T* src,
T** dst_block_ptrs,
const int* dst_cu_block_cnts,
const int* seq_lens,
int dst_offset,
int src_seq_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
template<typename T>
void ConvertBlocksToLinear(const T** src_block_ptrs,
T* dst,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_max_seq_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
void** dst_k_ptrs,
void** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len, // max{seq_lens}
int head_num,
int head_dim,
int batch_size,
int elem_bits,
cudaStream_t st);
template<typename T>
void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
T** dst_k_ptrs,
T** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "decoder_multihead_attention.h"
#include "kv_cache.h"
#include "test_utils.h"
#include <cmath>
#include <ios>
#include <iostream>
#include <thrust/universal_vector.h>
#include <algorithm>
#include <iomanip>
#include <numeric>
#include <random>
using namespace turbomind;
template<typename T>
T* align(T* ptr, size_t alignment)
{
size_t misalign = (uintptr_t)ptr % alignment;
std::cout << "misalignment: " << misalign << "\n";
if (misalign) {
return (T*)((uint8_t*)ptr + alignment - misalign);
}
return ptr;
}
// [S/S, H, S, D] <-> [S/b, H, b, D]
void TestBlocks(thrust::universal_vector<half>& linear, // linear data
thrust::universal_vector<half>& _blocks, // block data
thrust::universal_vector<half*>& _ptrs, // block ptrs
thrust::universal_vector<int>& _cu_block_cnts, // cumulative block counts
int head_num,
int head_dim,
int block_size,
int batch_size)
{
int seq_len = linear.size() / (head_dim * head_num * batch_size);
int n_blocks = (seq_len + block_size - 1) / block_size;
std::cout << "batch_size = " << batch_size << ", seq_len = " << seq_len << ", block_num = " << n_blocks
<< ", block_size = " << block_size << "\n";
thrust::universal_vector<half> blocks(batch_size * n_blocks * head_num * block_size * head_dim);
thrust::universal_vector<half*> ptrs(batch_size * n_blocks + 1); // +1 padding
std::vector<size_t> idxs(batch_size * n_blocks);
std::iota(idxs.begin(), idxs.end(), 0);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(idxs.begin(), idxs.end(), g);
for (int i = 0; i < idxs.size(); ++i) {
ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
}
thrust::universal_vector<int> seq_lens(batch_size);
thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);
std::vector<int> n_blocks_vec(batch_size + 1, n_blocks);
thrust::universal_vector<int> cu_block_cnts(batch_size + 1);
std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);
for (int i = 0; i < 10; ++i) {
ConvertLinearToBlocks((const half*)linear.data().get(),
ptrs.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
seq_len,
block_size,
head_num,
head_dim,
batch_size,
0);
}
thrust::universal_vector<half> _linear(linear.size());
for (int i = 0; i < 10; ++i) {
ConvertBlocksToLinear((const half**)ptrs.data().get(),
_linear.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
block_size,
seq_len,
head_num,
head_dim,
batch_size,
0);
}
cudaDeviceSynchronize();
if (0) {
std::cout << ">>> Compare\n";
Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
std::cout << "<<< Compare\n";
}
_blocks.swap(blocks);
_ptrs.swap(ptrs);
_cu_block_cnts.swap(cu_block_cnts);
}
int main(int argc, char* argv[])
{
DecoderMultiHeadAttentionParams<half> params{};
constexpr int kHeadNum = 32;
constexpr int kHeadDim = 128;
constexpr int KvHeadNum = 32;
constexpr int kBatchSize = 1;
// constexpr int kContextLen = 7306;
constexpr int kContextLen = 1024;
constexpr int kSequenceLen = kContextLen + 1;
constexpr int kBlockSz = 128;
constexpr int kTestIter = 10;
constexpr int kMaxSplitK = 1;
RNG rng{};
thrust::universal_vector<half> output(kBatchSize * kHeadNum * kHeadDim);
thrust::universal_vector<half> qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
thrust::universal_vector<bool> finished(kBatchSize);
thrust::universal_vector<half> k_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
thrust::universal_vector<half> v_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
thrust::universal_vector<int> sequence_lengths(kBatchSize);
thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
thrust::universal_vector<float> partial_M(kBatchSize * kHeadNum * kMaxSplitK);
thrust::universal_vector<float> partial_L(kBatchSize * kHeadNum * kMaxSplitK);
thrust::universal_vector<float> partial_O(kBatchSize * kHeadNum * kMaxSplitK * kHeadDim);
rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
if (kContextLen) {
rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim,
0,
sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum);
if constexpr (0) {
for (int b = 0; b < kBatchSize; ++b) {
for (int h = 0; h < KvHeadNum; ++h) {
for (int s = 0; s < kSequenceLen; ++s) {
for (int d = 0; d < kHeadDim; ++d) {
std::cout << std::setw(7) << std::setprecision(4) << std::fixed
<< (float)k_cache[b * KvHeadNum * kSequenceLen * kHeadDim
+ h * kSequenceLen * kHeadDim + s * kHeadDim + d]
<< " ";
}
std::cout << "\n";
}
std::cout << "\n";
}
std::cout << "\n";
}
std::exit(0);
}
cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim,
0,
sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum);
}
thrust::universal_vector<half> k_blocks;
thrust::universal_vector<half*> k_ptrs;
thrust::universal_vector<int> cu_block_cnts;
TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
thrust::universal_vector<half> v_blocks;
thrust::universal_vector<half*> v_ptrs;
TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
thrust::universal_vector<half> k_cache_ref = k_cache;
thrust::universal_vector<half> v_cache_ref = v_cache;
thrust::universal_vector<half> output_ref = output;
thrust::universal_vector<void*> k_cache_ref_ptrs(kBatchSize);
thrust::universal_vector<void*> v_cache_ref_ptrs(kBatchSize);
cudaDeviceSynchronize();
for (int i = 0; i < kBatchSize; ++i) {
sequence_lengths[i] = kContextLen;
k_cache_ptrs[i] = k_cache.data().get() + i * k_cache.size() / kBatchSize;
v_cache_ptrs[i] = v_cache.data().get() + i * v_cache.size() / kBatchSize;
k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
// align(k_cache_ptrs[i], 256);
// align(v_cache_ptrs[i], 256);
}
// getchar();
params.out = output_ref.data().get();
params.q = qkv.data().get();
params.k = params.q + kHeadNum * kHeadDim;
params.v = params.k + KvHeadNum * kHeadDim;
params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
params.batch_size = kBatchSize;
params.max_seq_len = kContextLen + 1;
params.cu_block_cnts = cu_block_cnts.data().get();
printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
params.k_cache_block_ptrs = (void**)k_ptrs.data().get();
params.v_cache_block_ptrs = (void**)v_ptrs.data().get();
params.kv_cache_block_size = kBlockSz;
params.finished = finished.data().get();
params.per_sample_length = sequence_lengths.data().get();
params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
params.layer_offset = 0;
params.num_heads = kHeadNum;
params.num_kv_heads = KvHeadNum;
params.size_per_head = kHeadDim;
params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head);
params.rotary_embedding_dim = kHeadDim;
params.rotary_embedding_base = 10000.f;
params.partial_L = partial_L.data().get();
params.partial_M = partial_M.data().get();
params.partial_O = partial_O.data().get();
for (int i = 0; i < kTestIter; ++i) {
mmha_ft_reference(params, cudaStream_t{});
}
cudaDeviceSynchronize();
if (auto err = cudaGetLastError(); err != cudaSuccess) {
std::cout << cudaGetErrorString(err) << "\n";
return -1;
}
std::cout << "---------------------------------------------------\n";
params.out = output.data().get();
params.per_sample_k_cache = k_cache_ptrs.data().get();
params.per_sample_v_cache = v_cache_ptrs.data().get();
params.max_split_k = kMaxSplitK;
params.max_seq_len = kContextLen;
params.arch = 80;
std::vector<thrust::universal_vector<half>> outputs;
for (int i = 0; i < std::max(kTestIter, 1); ++i) {
DispatchDecoderMultiheadAttention<half>(params);
if (auto err = cudaGetLastError(); err != cudaSuccess) {
std::cout << cudaGetErrorString(err) << "\n";
return -1;
}
if (1) {
outputs.push_back(output);
}
}
thrust::universal_vector<int> seq_lens(kBatchSize);
for (auto& x : seq_lens) {
x = kContextLen + 1;
}
if (1) {
ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
k_cache.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
kBlockSz,
kSequenceLen,
KvHeadNum,
kHeadDim,
kBatchSize,
0);
ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
v_cache.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
kBlockSz,
kSequenceLen,
KvHeadNum,
kHeadDim,
kBatchSize,
0);
}
cudaDeviceSynchronize();
if (outputs.size() > 1) {
std::cout << "Evaluating consistency..." << std::endl;
for (size_t i = 1; i < outputs.size(); ++i) {
Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadDim, kHeadNum);
}
}
std::cout << "---------------------------------------------------\n";
Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, false);
// [H, S, D]
Compare(k_cache.data().get() + kContextLen * kHeadDim,
k_cache_ref.data().get() + kContextLen * kHeadDim,
kSequenceLen * kHeadDim,
kHeadDim,
KvHeadNum);
Compare(v_cache.data().get() + kContextLen * kHeadDim,
v_cache_ref.data().get() + kContextLen * kHeadDim,
kSequenceLen * kHeadDim,
kHeadDim,
KvHeadNum);
return 0;
}
// Copyright (c) OpenMMLab. All rights reserved.
#include "test_utils.h"
#include <cublas_v2.h>
#include <curand.h>
#include <curand_kernel.h>
#include <fstream>
#include <iostream>
#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cooperative_groups/reduce.h>
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
namespace turbomind {
cublasHandle_t cublas_handle{};
cudaStream_t cublas_stream{};
template<typename T>
void Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)
{
float asums{};
float rsums{};
int outliers{};
for (int nn = 0; nn < n; ++nn) {
float abs_diff_sum{};
float rel_diff_sum{};
for (int mm = 0; mm < m; ++mm) {
auto x = float(src[nn * stride + mm]);
auto y = float(ref[nn * stride + mm]);
// if (show) {
// std::cout << x << "\t" << y << std::endl;
// }
auto abs_diff = std::abs(x - y);
auto rel_diff = abs_diff / (std::abs(y) + 1e-6f);
if (abs_diff > atol + rtol * std::abs(y)) {
++outliers;
if (show) {
std::cout << nn << "," << mm << "\t" << x << "\t" << y << std::endl;
}
}
abs_diff_sum += abs_diff;
rel_diff_sum += rel_diff;
}
asums += abs_diff_sum / m;
rsums += rel_diff_sum / m;
}
std::cout << "abs_diff = " << asums / n << " rel_diff = " << rsums / n << " outliers = " << outliers / (float)n
<< std::endl;
}
template void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
template void
Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
void LoadBinary(const std::string& path, size_t size, void* dst)
{
std::ifstream ifs(path, std::ios::binary | std::ios::in);
if (!ifs.is_open()) {
std::cerr << "failed to open " << path << "\n";
std::abort();
}
ifs.seekg(0, ifs.end);
auto actual_size_in_bytes = ifs.tellg();
ifs.seekg(0, ifs.beg);
if (size != actual_size_in_bytes) {
std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
<< " bytes is requested\n";
}
ifs.read((char*)dst, size);
std::cerr << "[info] " << path << " " << size << "\n";
}
namespace cg = cooperative_groups;
__global__ void curand_init(curandState* state)
{
auto tid = cg::this_grid().thread_rank();
curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);
}
template<typename T>
__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
float tmp = curand_uniform(state + grid.thread_rank());
result[i] = T(scale * tmp + shift);
}
}
template<typename T>
__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
float tmp = curand_normal(state + grid.thread_rank());
result[i] = T(scale * tmp + shift);
}
}
__global__ void curand_bytes(curandState* state, size_t count, uint* result)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
result[i] = curand(state + grid.thread_rank());
}
}
struct RNG::Impl {
curandState* states{};
Impl()
{
cudaMalloc(&states, sizeof(curandState) * 64 * 64);
curand_init<<<64, 64>>>(states);
}
~Impl()
{
cudaFree(states);
}
void GenerateUInt(uint* out, size_t count)
{
curand_bytes<<<64, 64>>>(states, count, out);
}
template<typename T>
void GenerateUniform(T* out, size_t count, float scale, float shift)
{
curand_uniform<<<64, 64>>>(states, count, out, scale, shift);
}
template<typename T>
void GenerateNormal(T* out, size_t count, float scale, float shift)
{
curand_normal<<<64, 64>>>(states, count, out, scale, shift);
}
};
RNG::RNG(): impl_(std::make_unique<Impl>()) {}
RNG::~RNG() = default;
void RNG::GenerateUInt(uint* out, size_t count)
{
impl_->GenerateUInt(out, count);
}
template<typename T>
void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
{
std::cout << count << std::endl;
impl_->GenerateUniform(out, count, scale, shift);
}
template<typename T>
void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
{
impl_->GenerateNormal(out, count, scale, shift);
}
template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
template<typename T>
struct SATypeConverter {
using Type = T;
};
template<>
struct SATypeConverter<half> {
using Type = uint16_t;
};
template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t st)
{
using DataType = typename SATypeConverter<T>::Type;
// Prepare the parameters.
Masked_multihead_attention_params<DataType> params{};
params.q_bias = reinterpret_cast<const DataType*>(p.q_bias);
params.k_bias = reinterpret_cast<const DataType*>(p.k_bias);
params.v_bias = reinterpret_cast<const DataType*>(p.v_bias);
// Set the output buffer.
params.out = reinterpret_cast<DataType*>(p.out);
// Set the input buffers.
// [B, nH + kvH, D]
params.q = reinterpret_cast<const DataType*>(p.q);
params.k = reinterpret_cast<const DataType*>(p.k);
params.v = reinterpret_cast<const DataType*>(p.v);
params.stride = p.stride;
params.finished = (bool*)p.finished;
params.k_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_k_cache);
params.v_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_v_cache);
params.kv_cache_per_sample_offset = p.layer_offset;
params.batch_size = p.batch_size;
params.beam_width = 1;
params.memory_max_len = p.max_seq_len;
params.prefix_prompt_lengths = 0;
params.max_prefix_prompt_length = 0;
params.length_per_sample = p.per_sample_length; // max_input_length + current output length
for (int i = 0; i < p.batch_size; ++i) {
params.timestep = std::max(p.per_sample_length[i], params.timestep);
}
std::cout << "timestep = " << params.timestep << "\n";
params.num_heads = p.num_heads;
params.num_kv_heads = p.num_kv_heads;
params.hidden_size_per_head = p.size_per_head;
params.rotary_embedding_dim = p.rotary_embedding_dim;
params.max_position_embeddings = p.max_position_embeddings;
params.use_dynamic_ntk = false;
params.use_logn_attn = p.use_logn_attn;
// Note: keep the norm factor (sqrt(K_dim)) when adopting the Megatron T5 structure (may need adjusting)
params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * 1.f);
params.int8_mode = 0;
masked_multihead_attention(params, st);
}
template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "decoder_multihead_attention.h"
#include "src/turbomind/macro.h"
#include <cuda_fp16.h>
#include <memory>
namespace turbomind {
template<typename T>
void Compare(
const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
void LoadBinary(const std::string& path, size_t size, void* dst);
class RNG {
public:
RNG();
~RNG();
void GenerateUInt(uint* out, size_t count);
template<typename T>
void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);
template<typename T>
void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params, cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "../gemm_s_f16/common.h"
namespace turbomind {
template<int C, int S, int AccessC, int WarpCount>
struct ThreadMapQ {
static constexpr int kWarpCount = WarpCount;
static constexpr int kAccessC = AccessC;
static constexpr int kWarpThreadC = C / kAccessC;
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
static_assert(kWarpThreadC <= WARP_SIZE);
static constexpr int kWarpAccessC = kWarpThreadC * kAccessC; // C
static constexpr int kWarpAccessS = kWarpThreadS;
static constexpr int kWarpIterC = C / kWarpAccessC; // 1
static constexpr int kWarpIterS = S / kWarpAccessS;
static constexpr int kWarpC = 1;
static constexpr int kWarpS = kWarpCount;
static constexpr int kIterC = kWarpIterC / kWarpC; // 1
static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
static constexpr int kFootprintC = kWarpAccessC * kIterC; // C
static constexpr int kFootprintS = kWarpAccessS * kIterS;
static constexpr int kDeltaC = kWarpAccessC;
static constexpr int kDeltaS = kWarpAccessS;
__device__ static int2 get_offset(int warp_id, int lane_id)
{
int warp_offset_c = warp_id % kWarpC;
int warp_offset_s = warp_id / kWarpC;
int warp_thread_offset_c = lane_id % kWarpThreadC;
int warp_thread_offset_s = lane_id / kWarpThreadC;
int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
return {cta_thread_offset_c, cta_thread_offset_s};
}
};
template<int C, int S, int AccessC, int WarpThreadC, int WarpCount>
struct ThreadMapKv {
static constexpr int kC = C;
static constexpr int kS = S;
static constexpr int kWarpCount = WarpCount;
static constexpr int kAccessC = AccessC;
static constexpr int kWarpThreadC = WarpThreadC;
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
static_assert(kWarpThreadC <= WARP_SIZE);
static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
static constexpr int kWarpAccessS = kWarpThreadS;
static constexpr int kWarpIterC = C / kWarpAccessC;
static constexpr int kWarpIterS = S / kWarpAccessS;
static constexpr int kWarpC = 1;
static constexpr int kWarpS = kWarpCount;
static constexpr int kIterC = kWarpIterC / kWarpC;
static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
static constexpr int kFootprintC = kWarpAccessC * kIterC;
static constexpr int kFootprintS = kWarpAccessS * kIterS;
static constexpr int kDeltaC = kWarpAccessC;
static constexpr int kDeltaS = kWarpAccessS;
__device__ static int2 get_offset(int warp_id, int lane_id)
{
int warp_offset_c = warp_id % kWarpC;
int warp_offset_s = warp_id / kWarpC;
int warp_thread_offset_c = lane_id % kWarpThreadC;
int warp_thread_offset_s = lane_id / kWarpThreadC;
int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
return {cta_thread_offset_c, cta_thread_offset_s};
}
};
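// Worked example with a hypothetical instantiation ThreadMapKv<128, 64, 8, 16, 4>
// (half elements, 16-byte accesses): kWarpThreadS = 2, kWarpAccessC = 128 = C, so
// kWarpIterC = kIterC = 1; kWarpIterS = 32 and with 4 warps kIterS = 8, giving each warp a
// footprint of 128 (C) x 16 (S). get_offset() then yields
//   c = (lane_id % 16) * 8,  s = 16 * warp_id + lane_id / 16,
// i.e. 16 lanes cover a full 128-element row and each warp owns 16 consecutive S rows.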
} // namespace turbomind
......@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <cstddef>
#include <cstdint>
namespace turbomind {
......@@ -236,7 +237,13 @@ struct IteratorA {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
#else
if (mask) {
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
}
#endif
}
};
......@@ -417,7 +424,13 @@ struct IteratorQ {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
#else
if (mask) {
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
}
#endif
}
};
......@@ -613,8 +626,14 @@ struct IteratorB {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_cg_B(
smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
#else
if (is_valid_n_ && mask) {
*(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
}
#endif
}
};
......