Unverified commit ab1767cf, authored by Li Zhang and committed by GitHub

TurboMind 2 (#590)

* refresh decoder attention kernel

* block-level kv cache

* `BlockManager` & `SequenceManager`

* update

* update

* update

* update

* rename

* GQA support

* fix context length

* GQA dispatch

* kv8

* tune

* async stream cb

* nvtx

* config parsing

* debug

* optimize output cost

* split-k decoding

* minor

* truncate `session_len` by available blocks

* minor

* license

* fix

* dispatch `cp.async`

* fix linking

* fix

* fix deadlock

* guard input length

* correct start offset

* fix prefill chunking

* fix `cache_block_seq_len` param passing

* fix `block_size` fmtstr

* fix output tokens

* fix batch resizing

* fix masking of finished sequences

* add debug util

* free unused block early

* add ntk scaling and logn scaling

* cmake flags

* fix typo

* w4a16 for sm75

* fix msvc build

* fix msvc build

* fix block verification

* fix msvc build

* use `std::shuffle`

* fix lint

* fix lint

* fix lint

* clear incoming buffer

* clear finished requests

* fix batch initialization

* fix typo

* fix typo

* fix comparison
parent 06125966
...@@ -72,3 +72,5 @@ work_dir*/
*.out
*.csv
*.pkl
!CMakeLists.txt
...@@ -61,6 +61,22 @@ option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF
option(BUILD_FAST_MATH "Build in fast math mode" ON)
# the environment variable
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
# must be set at runtime
# https://github.com/google/sanitizers/issues/1322
if (LMDEPLOY_ASAN_ENABLE)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=address>)
add_link_options(-fsanitize=address)
endif ()
# notice that ubsan has linker issues for ubuntu < 18.04, see
# https://stackoverflow.com/questions/50024731/ld-unrecognized-option-push-state-no-as-needed
if (LMDEPLOY_UBSAN_ENABLE)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:-fsanitize=undefined>)
add_link_options(-fsanitize=undefined)
endif ()
if(BUILD_MULTI_GPU)
message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL")
add_definitions("-DBUILD_MULTI_GPU")
...@@ -181,11 +197,15 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
if(BUILD_FAST_MATH)
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}")
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} --use_fast_math")
message("Release build CUDA flags: ${CMAKE_CUDA_FLAGS_RELEASE}")
endif()
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
...@@ -252,11 +272,15 @@ print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');"
OUTPUT_VARIABLE USE_CXX11_ABI)
message("-- USE_CXX11_ABI=${USE_CXX11_ABI}")
if (USE_CXX11_ABI)
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1")
else()
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0")
...@@ -327,6 +351,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:cuda_utils>
$<TARGET_OBJECTS:custom_ar_comm>
$<TARGET_OBJECTS:custom_ar_kernels>
$<TARGET_OBJECTS:decoder_multihead_attention>
$<TARGET_OBJECTS:decoder_masked_multihead_attention>
$<TARGET_OBJECTS:decoding_kernels>
$<TARGET_OBJECTS:gpt_kernels>
......
...@@ -80,8 +80,10 @@ broadCastRequest(const std::vector<int>& v_start_ids,
if (node_id == 0) {
memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int));
memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int));
if (!v_input_bad_words.empty()) {
memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int));
}
}
if (kUSE_MPI) {
ft::mpi::barrier();
}
...@@ -431,6 +433,8 @@ int main(int argc, char* argv[])
const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1];
const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2];
ft::FT_CHECK(beam_width == 1);
std::vector<int> seq_lens(batch_size);
// step 6: check results
if (node_id == 0) {
...@@ -440,32 +444,25 @@ int main(int argc, char* argv[])
printf("[WARNING] Cannot write results into output file %s \n", fName.c_str());
}
else {
size_t outCount = batch_size * beam_width * seq_len;
const size_t outCount = batch_size * beam_width * seq_len;
// int* hBuf = new int[outCount];
std::vector<int> hBuf(outCount);
ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
std::cout << "sequence length: ";
for (int i = 0; i < batch_size; ++i) {
std::cout << (i ? ", " : "") << seq_lens[i];
}
std::cout << "\n";
{
std::cout << "Writing " << outCount << " elements\n";
int zeroCount = 0;
for (size_t i = 0; i < outCount; i++) {
if (hBuf[i] == int(0))
zeroCount++;
outFile << hBuf[i] << " ";
if ((i + 1) % (seq_len) == 0)
outFile << std::endl;
if (i < 10)
printf("%5d ", hBuf[i]);
if ((i + 1) % (seq_len) == 0 && i < 10)
std::cout << std::endl;
}
std::cout << std::endl << "zeroCount = " << zeroCount << std::endl;
}
for (int i = 0; i < batch_size; ++i) {
outFile << (i ? "\n" : "");
auto buf = hBuf.data() + seq_len * i;
for (int j = 0; j < seq_lens[i]; ++j) {
outFile << buf[j] << " ";
}
}
}
}
...@@ -475,7 +472,7 @@ int main(int argc, char* argv[])
}
cudaDeviceSynchronize();
if (1) {
if (0) {
// test time
auto start = std::chrono::high_resolution_clock::now();
......
...@@ -71,3 +71,4 @@ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(gemm_s_f16)
add_subdirectory(decoder_multihead_attention)
...@@ -43,6 +43,12 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
// cudaFuncAttributes attr{}; \
// cudaFuncGetAttributes(&attr, func); \
// std::cout << "static_smem_sz: " << attr.sharedSizeBytes << std::endl; \
// std::cout << "max_dynamic_smem: " << attr.maxDynamicSharedSizeBytes << std::endl; \
// std::cout << "dynamic_smem_sz: " << smem_sz << std::endl; \
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
......
...@@ -1472,6 +1472,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
printf("QK_last[%d] = %f\n", hi, qk);
qk_max = qk;
qk_smem[tlength - first_step] = qk;
// qk_smem[params.timestep] = qk;
...@@ -1596,6 +1598,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
}
// printf("QK_%d = %f\n", (int)ti, qk);
qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = qk;
}
...@@ -1632,6 +1635,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Broadcast to all the threads in the warp.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
if (threadIdx.x == 0) {
printf("QK_MAX[%d] = %f\n", hi, (float)qk_max);
}
// Compute the logits and start the sum.
float sum = 0.f;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
...@@ -1657,6 +1664,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Compute the sum.
sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
if (threadIdx.x == 0) {
printf("SUM[%d] = %f\n", hi, (float)sum);
}
// Normalize the logits.
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
......
# Copyright (c) OpenMMLab. All rights reserved.
add_library(decoder_multihead_attention STATIC decoder_multihead_attention.cu kv_cache.cu)
# target_compile_options(decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr --keep)
set_property(TARGET decoder_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET decoder_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(decoder_multihead_attention PRIVATE nvidia::cutlass::cutlass)
add_executable(test_decoder_multihead_attention test_utils.cu test_decoder_multihead_attention.cu)
# target_compile_options(test_decoder_multihead_attention PRIVATE
# --generate-line-info -O3 -use_fast_math -Xptxas=-v --expt-relaxed-constexpr)
target_link_libraries(test_decoder_multihead_attention PRIVATE
decoder_multihead_attention
decoder_masked_multihead_attention
cublas)
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "src/turbomind/kernels/gemm_s_f16/common.h"
#include <cfloat>
#include <limits>
namespace turbomind {
namespace ops {
template<typename T>
struct plus {
__device__ T operator()(T a, T b)
{
return a + b;
}
};
template<typename T>
struct minus {
__device__ T operator()(T a, T b)
{
return a - b;
}
};
template<typename T>
struct multiplies {
__device__ T operator()(T a, T b)
{
return a * b;
}
};
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_vv(const Array<T, N>& a, const Array<T, N>& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a[i], b[i]);
}
return c;
}
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_sv(const T& a, const Array<T, N>& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a, b[i]);
}
return c;
}
template<typename T, int N, typename Op>
inline __device__ Array<T, N> binary_op_vs(const Array<T, N>& a, const T& b, Op op)
{
Array<T, N> c;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
c[i] = op(a[i], b);
}
return c;
}
template<typename T, int N>
inline __device__ Array<T, N> operator+(const Array<T, N>& a, const Array<T, N>& b)
{
return binary_op_vv(a, b, plus<T>{});
}
template<typename T, int N>
inline __device__ Array<T, N> operator*(const Array<T, N>& a, const Array<T, N>& b)
{
return binary_op_vv(a, b, multiplies<T>{});
}
template<typename T, int N>
inline __device__ Array<T, N> operator*(const Array<T, N>& a, const T& b)
{
return binary_op_vs(a, b, multiplies<T>{});
}
} // namespace ops
template<typename To, typename From, int N>
inline __device__ Array<To, N> cast(const Array<From, N>& src)
{
Array<To, N> dst;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
dst[i] = (To)src[i];
}
return dst;
}
template<int N>
struct RotaryEmbedding {
static_assert(N % 2 == 0);
Array<float, N> cs_;
__device__ RotaryEmbedding(float base, int dims, int timestep, int2 offset)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
const float2 tmp = get_coefficient(offset.x + i, dims, base, timestep);
cs_[i] = tmp.x;
cs_[i + 1] = tmp.y;
}
}
static __device__ inline float2 get_coefficient(int idx, int dims, float base, int timestep)
{
const float inv_freq = timestep / powf(base, idx / (float)dims);
float2 cs;
sincosf(inv_freq, &cs.y, &cs.x);
return cs;
}
template<typename T>
__device__ void apply(Array<T, N>& x)
{
PRAGMA_UNROLL
for (int i = 0; i < N; i += 2) {
float tmp0 = cs_[i] * (float)x[i] - cs_[i + 1] * (float)x[i + 1];
float tmp1 = cs_[i] * (float)x[i + 1] + cs_[i + 1] * (float)x[i];
x[i] = (T)tmp0;
x[i + 1] = (T)tmp1;
}
}
};
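[Editor's note] The RotaryEmbedding helper above precomputes interleaved cos/sin coefficients with get_coefficient() (angle = timestep / base^(idx / dims), cos stored in .x and sin in .y) and apply() rotates each adjacent (even, odd) channel pair by that angle. Below is a minimal host-side C++ sketch of the same arithmetic for a single pair; it is illustrative only, not part of the commit, and the constants are made up.

// Host-side sketch mirroring RotaryEmbedding's math for one (even, odd) pair.
#include <cmath>
#include <cstdio>

int main() {
    const float base     = 10000.f; // rotary_embedding_base (hypothetical)
    const int   dims     = 128;     // rotary_embedding_dim (hypothetical)
    const int   idx      = 0;       // even channel index (0, 2, 4, ...)
    const int   timestep = 42;      // token position

    // Same formula as get_coefficient(): angle = timestep / base^(idx / dims)
    const float angle = timestep / powf(base, idx / (float)dims);
    const float c = cosf(angle), s = sinf(angle);

    // Rotate one (x0, x1) pair exactly as RotaryEmbedding::apply() does.
    float x0 = 1.0f, x1 = 0.5f;
    const float y0 = c * x0 - s * x1;
    const float y1 = c * x1 + s * x0;
    printf("rotated: (%f, %f)\n", y0, y1);
}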
struct LogNScaling {
float scale_;
__device__ static float get_scale(int seq_len, int max_position_embeddings)
{
if (seq_len <= max_position_embeddings) {
return 1.f;
}
else {
return log2f(seq_len) / log2f(max_position_embeddings);
}
}
__device__ LogNScaling(int seq_len, int max_position_embeddings)
{
scale_ = get_scale(seq_len, max_position_embeddings);
}
template<typename T, int N>
__device__ void apply(Array<T, N>& x) const
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
x[i] = (T)((float)x[i] * scale_);
}
}
};
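[Editor's note] LogNScaling leaves positions inside the trained window untouched and scales by log(seq_len) / log(max_position_embeddings) once the sequence exceeds max_position_embeddings (the log2f ratio equals the natural-log ratio). A tiny host-side sketch of get_scale() with made-up numbers, for illustration only:

// Host-side sketch of LogNScaling::get_scale().
#include <cmath>
#include <cstdio>

static float logn_scale(int seq_len, int max_position_embeddings) {
    if (seq_len <= max_position_embeddings)
        return 1.f;  // no scaling inside the trained window
    return log2f((float)seq_len) / log2f((float)max_position_embeddings);  // = log_m(n)
}

int main() {
    // Hypothetical numbers: a model trained with 2048 positions, decoding at 8192.
    printf("scale = %f\n", logn_scale(8192, 2048));  // log(8192)/log(2048) = 13/11 ~ 1.18
}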
template<typename T, int N>
inline __device__ void Store(T* dst, const Array<T, N>& src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
*(uint4*)dst = (const uint4&)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
*(uint2*)dst = (const uint2&)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint1)) {
*(uint1*)dst = (const uint1&)src;
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename T, int N>
inline __device__ void Ldg(Array<T, N>& dst, const T* src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
(uint4&)dst = __ldg((const uint4*)src);
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
(uint2&)dst = __ldg((const uint2*)src);
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
(uint&)dst = __ldg((const uint*)src);
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename T, int N>
inline __device__ void Lds(Array<T, N>& dst, const T* src)
{
static_assert(sizeof(Array<T, N>) <= sizeof(uint4));
if constexpr (sizeof(Array<T, N>) == sizeof(uint4)) {
(uint4&)dst = *(const uint4*)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint2)) {
(uint2&)dst = *(const uint2*)src;
}
else if constexpr (sizeof(Array<T, N>) == sizeof(uint)) {
(uint1&)dst = *(const uint1*)src;
}
else {
static_assert(!std::is_same_v<T, T>);
}
}
template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N, int V>
inline __device__ Accum qk_dot(const Array<Tq, N> (&q)[V], const Array<Tk, N> (&k)[V])
{
Accum accum{};
PRAGMA_UNROLL
for (int vi = 0; vi < V; ++vi) {
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
accum += Accum(Compute(q[vi][i]) * Compute(k[vi][i]));
}
}
PRAGMA_UNROLL
for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
}
return accum;
}
template<typename Accum, typename Compute, int kThreadGroupSize, typename Tq, typename Tk, int N>
inline __device__ Accum qk_dot(const Array<Tq, N>& q, const Array<Tk, N>& k)
{
Accum accum{};
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
accum += Accum(Compute(q[i]) * Compute(k[i]));
}
PRAGMA_UNROLL
for (int mask = kThreadGroupSize / 2; mask >= 1; mask /= 2) {
accum += __shfl_xor_sync((uint32_t)-1, accum, mask);
}
return accum;
}
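[Editor's note] Both qk_dot overloads accumulate a per-thread partial dot product over the fragment of the head dimension the thread owns, then combine lanes across the thread group (kThreadGroupSize) with a __shfl_xor_sync butterfly so every participating lane ends up with the full dot product. A host-side emulation of that butterfly, illustrative only:

// Host-side emulation of the XOR butterfly reduction used by qk_dot: each of
// `group` lanes holds a partial sum; pairing by XOR halves the distance each
// round until every lane holds the total.
#include <cstdio>
#include <vector>

int main() {
    const int group = 8;  // kThreadGroupSize (hypothetical)
    std::vector<float> partial = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};  // per-lane fragments

    for (int mask = group / 2; mask >= 1; mask /= 2) {
        std::vector<float> next(partial);
        for (int lane = 0; lane < group; ++lane)
            next[lane] = partial[lane] + partial[lane ^ mask];  // __shfl_xor_sync analogue
        partial.swap(next);
    }
    printf("every lane now holds %f (expected 36)\n", partial[0]);
}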
template<typename ComputeType, typename Tp, typename Tv, typename To, int N, int M>
inline __device__ void fma_pv(Tp pr, const Array<Tv, N> (&v)[M], Array<To, N> (&o)[M])
{
PRAGMA_UNROLL
for (int m = 0; m < M; ++m) {
PRAGMA_UNROLL
for (int n = 0; n < N; ++n) {
o[m][n] += To(ComputeType(v[m][n]) * ComputeType(pr));
}
}
}
template<typename ThreadMap, typename T, int N>
inline __device__ Array<T, N> qk_max(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
{
constexpr int kWarpCount = ThreadMap::kWarpCount;
// warp maximum
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
PRAGMA_UNROLL
for (int mask = WARP_SIZE / 2; mask >= ThreadMap::kWarpThreadC; mask /= 2) {
val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
}
if (lane_id == 0) {
smem_red[i * kWarpCount + warp_id] = val[i];
}
}
__syncthreads();
// block maximum
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : -FLT_MAX;
PRAGMA_UNROLL
for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
val[i] = fmaxf(val[i], __shfl_xor_sync((uint32_t)-1, val[i], mask));
}
// broadcast to all threads
val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
}
return val;
}
template<int kWarpCount, typename T, int N>
inline __device__ Array<T, N> blockSum(Array<T, N> val, T* smem_red, int warp_id, int lane_id)
{
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
PRAGMA_UNROLL
for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
}
if (lane_id == 0) {
smem_red[i * kWarpCount + warp_id] = val[i];
}
}
__syncthreads();
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
val[i] = lane_id < kWarpCount ? smem_red[i * kWarpCount + lane_id] : T{};
PRAGMA_UNROLL
for (int mask = kWarpCount >> 1; mask >= 1; mask >>= 1) {
val[i] += __shfl_xor_sync((uint32_t)-1, val[i], mask);
}
val[i] = __shfl_sync((uint32_t)-1, val[i], 0);
}
return val;
}
//////////////////////////////////////////////////////////////////////////////////////////////////
// generic case for floating point -> floating point / integer -> integer conversion
template<typename Ti, typename To, typename = void>
struct ConvertKvCache {
__device__ __host__ ConvertKvCache(float, float) {}
template<int N>
inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<To, N>
{
Array<To, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
vo[i] = (To)vi[i];
}
return vo;
}
};
// generic case for converting to same type, bypass
template<typename T>
struct ConvertKvCache<T, T> {
__device__ __host__ ConvertKvCache(float, float) {}
template<int N>
inline __device__ auto operator()(const Array<T, N>& v) const -> Array<T, N>
{
return v;
}
};
template<typename Ti>
struct ConvertKvCache<Ti, int8_t> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero) {}
inline __device__ uint8_t round(float x) const
{
uint32_t y;
asm("cvt.rni.sat.u8.f32 %0, %1;\n" : "=r"(y) : "f"(x));
return y;
}
template<int N>
inline __device__ auto operator()(const Array<Ti, N>& vi) const -> Array<int8_t, N>
{
Array<int8_t, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
// convert to unsigned int by offsetting +128
(uint8_t&)vo[i] = round(((float)vi[i] - zero_) / scale_ + 128.f);
}
return vo;
}
};
inline __device__ Array<float, 4> fast_i2f_f32_s8(const Array<int8_t, 4>& x)
{
union {
Array<float, 4> f32x4;
Array<uint32_t, 4> u32x4;
};
auto& i8s = (const uint32_t&)x;
// 00000000111111112222222233333333
// 01234567012345670123456701234567
// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
// 0????????_______XXXXXXXX________
// (1 + x / 2^15) * 2^(e - 127) -> e - 127 == 15 -> e = 142
// 7 6 5 4
static constexpr uint32_t f32_magic = 0x47000000; // 2^15 = 32768
static constexpr uint32_t m0 = 0x7604;
static constexpr uint32_t m1 = 0x7614;
static constexpr uint32_t m2 = 0x7624;
static constexpr uint32_t m3 = 0x7634;
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[0]) : "r"(i8s), "n"(f32_magic), "n"(m0));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[1]) : "r"(i8s), "n"(f32_magic), "n"(m1));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[2]) : "r"(i8s), "n"(f32_magic), "n"(m2));
asm("prmt.b32 %0,%1,%2,%3;\n" : "=r"(u32x4[3]) : "r"(i8s), "n"(f32_magic), "n"(m3));
if (0) { // fused with dequantization
PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
f32x4[i] -= 32896.f; // 32768 + 128
}
}
return f32x4;
}
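[Editor's note] fast_i2f_f32_s8 converts four packed (already +128-offset) int8 values to float without int-to-float instructions: prmt.b32 places each byte into mantissa bits 8..15 of the pattern 0x47000000 (the float 32768.0f), so the assembled float equals 32768 + byte; subtracting 32896 = 32768 + 128, done later and folded into the zero point, recovers the signed value. A host-side check of that identity, illustrative only:

// Host-side check of the magic-number trick behind fast_i2f_f32_s8.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    for (int v = -128; v <= 127; v += 51) {                        // a few signed int8 samples
        const uint8_t  biased = (uint8_t)(v + 128);                // stored form (see ConvertKvCache<Ti, int8_t>)
        const uint32_t bits   = 0x47000000u | ((uint32_t)biased << 8);  // byte lands in mantissa bits 8..15
        float f;
        std::memcpy(&f, &bits, sizeof(f));                         // reinterpret as float: 32768 + biased
        printf("int8 %4d -> float %.1f -> recovered %4.0f\n", v, f, f - 32896.f);
    }
}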
template<>
struct ConvertKvCache<int8_t, float> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
{
zero_ = zero_ - 32896.f * scale_;
}
template<int N>
inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<float, N>
{
Array<float, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 4) {
auto& vec = (Array<float, 4>&)vo[i];
vec = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
PRAGMA_UNROLL
for (int j = 0; j < 4; ++j) {
vec[j] = vec[j] * scale_ + zero_;
// vec[j] = vec[j] * scale_ + (zero_ - 32896.f * scale_);
}
}
return vo;
}
};
template<>
struct ConvertKvCache<int8_t, half> {
float scale_;
float zero_;
__device__ __host__ ConvertKvCache(float scale, float zero): scale_(scale), zero_(zero)
{
zero_ = zero_ - 32896.f * scale_;
}
template<int N>
inline __device__ auto operator()(const Array<int8_t, N>& vi) const -> Array<half, N>
{
Array<half, N> vo;
PRAGMA_UNROLL
for (int i = 0; i < N; i += 4) {
auto& vec = (Array<half, 4>&)vo[i];
auto tmp = fast_i2f_f32_s8((const Array<int8_t, 4>&)vi[i]);
PRAGMA_UNROLL
for (int j = 0; j < 4; ++j) {
vec[j] = half(tmp[j] * scale_ + zero_);
// vec[j] = half(tmp[j] * scale_ + (zero_ - 32896.f * scale_));
}
}
return vo;
}
};
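[Editor's note] Putting the pieces together: ConvertKvCache<Ti, int8_t> stores round((x - zero) / scale + 128) as an unsigned byte, and the <int8_t, float>/<int8_t, half> specializations fold the constant 32896 * scale into their zero point so dequantization after fast_i2f_f32_s8 is one multiply-add per element. A host-side round-trip sketch with made-up quantization parameters, for illustration only:

// Host-side round-trip sketch of the int8 KV-cache path.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const float scale = 0.05f, zero = 0.2f;   // hypothetical per-tensor quant params
    const float x = -1.337f;                  // a value to store in the cache

    // Quantization: offset by +128 so the stored byte is unsigned.
    const uint8_t stored = (uint8_t)lrintf((x - zero) / scale + 128.f);

    // The fast conversion path produces 32768 + stored (see fast_i2f_f32_s8).
    const float raw = 32768.f + (float)stored;

    // Dequantization with the folded zero point: one multiply-add per element.
    const float folded_zero = zero - 32896.f * scale;   // 32896 = 32768 + 128
    const float x_hat = raw * scale + folded_zero;

    printf("x = %f  ->  stored = %u  ->  x_hat = %f\n", x, stored, x_hat);
}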
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "decoder_multihead_attention_template.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
#include <iostream>
namespace turbomind {
namespace {
template<typename MHAType>
bool Print(size_t dynamic_smem_size)
{
using MapKv = typename MHAType::MapKv;
std::cout << " warps: " << MapKv::kWarpCount << "\n";
std::cout << " shape: (" << MapKv::kC << ", " << MapKv::kS << ")\n";
std::cout << " access: (" << MapKv::kAccessC << ", " << 1 << ")\n";
std::cout << "warpThread: (" << MapKv::kWarpThreadC << ", " << MapKv::kWarpThreadS << ")\n";
std::cout << "warpAccess: (" << MapKv::kWarpAccessC << ", " << MapKv::kWarpAccessS << ")\n";
std::cout << " warpIter: (" << MapKv::kWarpIterC << ", " << MapKv::kWarpIterS << ")\n";
std::cout << " warp: (" << MapKv::kWarpC << ", " << MapKv::kWarpS << ")\n";
std::cout << " iter: (" << MapKv::kIterC << ", " << MapKv::kIterS << ")\n";
std::cout << " footprint: (" << MapKv::kFootprintC << ", " << MapKv::kFootprintS << ")\n";
std::cout << " delta: (" << MapKv::kDeltaC << ", " << MapKv::kDeltaS << ")\n";
std::cout << "dynamic smem size: " << dynamic_smem_size << "\n";
return true;
}
} // namespace
template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
{
auto invoke = [&](auto* type) {
using Attn = std::remove_reference_t<decltype(*type)>;
static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
// [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
dim3 block(Attn::kWarpCount * WARP_SIZE);
dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
// if (params.layer_offset == 0) {
// std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
// }
cudaFuncSetAttribute(
decoder_multihead_attention<Attn>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
decoder_multihead_attention<Attn><<<grid, block, kDynSmemSize, params.stream>>>(params);
if (max_split_k > 1) {
dim3 grid(params.num_heads, params.batch_size);
decoder_multihead_attention_reduce<Attn><<<grid, block, 0, params.stream>>>(params);
}
};
if (params.arch >= 80) {
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>; // 64k
using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
invoke((Type*)0);
}
else {
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>; // 34k
// DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>; // 34k
using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 1024, 3, true>;
invoke((Type*)0);
}
}
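[Editor's note] The launcher above derives its shape from the sequence length: the sequence is cut into Attn::kSliceLen slices, the requested split-k is capped by the slice count, the main kernel runs on a (num_heads / HeadPerCta, batch_size, max_split_k) grid, and a separate reduce kernel merges the per-split partials when max_split_k > 1. A host-side sketch of that arithmetic; the constants are hypothetical (kSliceLen is assumed to correspond to the 1024 in the instantiations above):

// Host-side sketch of the launch-shape arithmetic in invokeDecoderMultiheadAttention.
#include <algorithm>
#include <cstdio>

int main() {
    const int kSliceLen    = 1024;  // assumed Attn::kSliceLen
    const int kWarpCount   = 4;     // hypothetical Attn::kWarpCount
    const int head_per_cta = 1, num_heads = 32, batch_size = 4;
    const int max_seq_len = 7000, requested_split_k = 8;

    const int slice_count = (max_seq_len + kSliceLen - 1) / kSliceLen;           // ceil-div
    const int max_split_k = std::min(requested_split_k, std::max(1, slice_count));

    printf("block = %d threads, grid = (%d, %d, %d)\n",
           kWarpCount * 32, num_heads / head_per_cta, batch_size, max_split_k);
    if (max_split_k > 1)
        printf("reduce pass grid = (%d, %d)\n", num_heads, batch_size);          // partials merged per head
}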
template<typename T>
void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
{
static constexpr int HeadDim = 128;
FT_CHECK(params.size_per_head == HeadDim);
if constexpr (std::is_same_v<T, half>) {
if (params.quant_policy & QuantPolicy::kCacheKVInt8) {
invokeDecoderMultiheadAttention<T, int8_t, HeadDim, 1>(params);
return;
}
int group_size = params.num_heads / params.num_kv_heads;
if (0) {}
// else if (group_size % 8 == 0) {
// invokeDecoderMultiheadAttention<T, T, HeadDim, 8>(params);
// }
else if (group_size % 4 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 4>(params);
}
else if (group_size % 2 == 0) {
invokeDecoderMultiheadAttention<T, T, HeadDim, 2>(params);
}
else {
invokeDecoderMultiheadAttention<T, T, HeadDim, 1>(params);
}
}
}
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<half>& params);
template void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<float>& params);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "decoder_multihead_attention_params.h"
namespace turbomind {
template<typename T>
void DispatchDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params);
}
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cuda_runtime.h>
namespace turbomind {
template<typename T>
struct DecoderMultiHeadAttentionParams {
// token-level buffers, [B, qH + 2kvH, D] or [B, kvH, D]
T* __restrict__ out;
T* __restrict__ q;
T* __restrict__ k;
T* __restrict__ v;
int stride;
// bias, [qH, D] or [kvH, D]
T* __restrict__ q_bias;
T* __restrict__ k_bias;
T* __restrict__ v_bias;
// sequence-level buffers
const int* __restrict__ per_sample_length;
const bool* __restrict__ finished;
const float* __restrict__ rope_theta;
// kv cache
void** __restrict__ per_sample_k_cache; // [H, S, D]
void** __restrict__ per_sample_v_cache; // [H, S, D]
size_t layer_offset;
/// cache layout M,[N,H,x,D]
/// S: [s0/x, s1/x, s2/x, ..., sn-1/x], si <- block
/// 1. [L,sum(S),H,x,D]
void** __restrict__ k_cache_block_ptrs; // X,[H,x,D]
void** __restrict__ v_cache_block_ptrs; // X,[H,x,D]
int* __restrict__ cu_block_cnts; // [B+1]
int kv_cache_block_size;
// batch-level params
int batch_size;
int max_seq_len;
// instance-level params
int num_heads;
int num_kv_heads;
int size_per_head;
float inv_sqrt_dh;
// rotary embedding
int rotary_embedding_dim;
float rotary_embedding_base;
int max_position_embeddings;
// bool use_dynamic_ntk;
// log(n) attention
bool use_logn_attn;
int quant_policy;
float kv_quant_params[4];
int max_split_k;
float* partial_O;
float* partial_M;
float* partial_L;
int arch;
cudaStream_t stream;
};
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "../gemm_s_f16/common.h"
#include "array_ops.h"
namespace turbomind {
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
struct BlockIterator {
const void** ptrs_;
const void* prefetch_;
BlockIterator() = default;
__device__ BlockIterator(const void** block_ptrs): ptrs_{block_ptrs}
{
// prefetch first ptr
prefetch_ = *ptrs_++;
}
__device__ const void* Next()
{
// return prefetched ptr
const void* ret = prefetch_;
// prefetch next ptr
prefetch_ = *ptrs_++;
return ret;
}
};
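[Editor's note] BlockIterator walks the per-sequence list of KV-cache block pointers while always keeping the next pointer already read, so the pointer chase overlaps with work on the current block. A host-side sketch of the same prefetch-one-ahead pattern, illustrative only:

// Host-side sketch of BlockIterator's prefetch-one-ahead pointer walk.
#include <cstdio>

struct HostBlockIterator {
    const void** ptrs;
    const void*  prefetch;
    explicit HostBlockIterator(const void** block_ptrs) : ptrs(block_ptrs) {
        prefetch = *ptrs++;          // read the first pointer up front
    }
    const void* Next() {
        const void* ret = prefetch;  // hand out the pointer fetched last time
        prefetch = *ptrs++;          // immediately start on the next one
        return ret;
    }
};

int main() {
    int a, b, c;
    const void* blocks[] = {&a, &b, &c, nullptr};  // trailing entry keeps the last prefetch in bounds
    HostBlockIterator it(blocks);
    for (int i = 0; i < 3; ++i)
        printf("block %d -> %p\n", i, it.Next());
}

Note the trailing entry in the pointer array: the iterator always reads one pointer past the last block it hands out, which is also why the test further below allocates batch_size * n_blocks + 1 pointers.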
template<typename T, typename ThreadMap, int BlockLen, int Stages, bool kUseBlockIter>
struct Iterator {
using ElementType = T;
using AccessType = Array<T, ThreadMap::kAccessC>;
static constexpr int kElementSize = sizeof(ElementType);
static constexpr int kAccessSize = sizeof(AccessType);
static constexpr int kSizePerTile = ThreadMap::kS * ThreadMap::kC;
static constexpr int kSmemByteSize = kElementSize * Stages * kSizePerTile;
BlockIterator block_iterator_;
static constexpr int kIterCount = ThreadMap::kIterS * ThreadMap::kIterC;
static constexpr int kStepC = ThreadMap::kDeltaC;
static constexpr int kStepS = ThreadMap::kDeltaS * ThreadMap::kC - ThreadMap::kIterC * kStepC;
static constexpr int kStepK =
ThreadMap::kS * ThreadMap::kC - ThreadMap::kIterS * ThreadMap::kDeltaS * ThreadMap::kC;
// (C, S, K) = (64, 384, 1536)
// initial offset, used to reset src_offset when switching to a new block
int init_offset_;
int src_offset_;
int dst_offset_;
int iter_c_;
int iter_b_;
int seq_len_;
int offset_s_;
bool is_valid_s_;
int block_size_;
int block_k_;
int layer_offset_;
int head_idx_;
const T* __restrict__ src_;
T* __restrict__ smem_;
int smem_read_offset_;
struct __align__(sizeof(AccessType)) SharedStorage
{
T smem_[Stages][kSizePerTile];
};
Iterator() = default;
__device__ Iterator(T* src, T* smem, int step, int seq_len, int warp_id, int lane_id)
{
src_ = src;
smem_ = smem;
int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
src_offset_ = init_offset_ + step * ThreadMap::kC;
dst_offset_ = init_offset_;
smem_read_offset_ = init_offset_;
iter_c_ = 0;
iter_b_ = 0;
seq_len_ = seq_len;
offset_s_ = init_offset_cs.y + step;
is_valid_s_ = offset_s_ < seq_len;
}
__device__ Iterator(const void** block_ptrs,
int block_size,
int layer_offset,
int head_idx,
T* smem,
int step,
int seqlen,
int warp_id,
int lane_id)
{
// src_ = src;
int block_index = step / block_size;
block_size_ = block_size;
block_k_ = (block_index + 1) * block_size - step; // offset to next block
layer_offset_ = layer_offset;
head_idx_ = head_idx;
block_iterator_ = BlockIterator(block_ptrs + block_index);
src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
smem_ = smem;
int2 init_offset_cs = ThreadMap::get_offset(warp_id, lane_id);
init_offset_ = init_offset_cs.x + init_offset_cs.y * ThreadMap::kC;
src_offset_ = init_offset_ + (step - block_index * block_size) * ThreadMap::kC;
dst_offset_ = init_offset_;
smem_read_offset_ = init_offset_;
iter_c_ = 0;
iter_b_ = 0;
seq_len_ = seqlen;
offset_s_ = init_offset_cs.y + step;
is_valid_s_ = offset_s_ < seqlen;
}
__device__ void PrefetchStage()
{
PRAGMA_UNROLL
for (int i = 0; i < kIterCount; ++i) {
Prefetch(is_valid_s_);
++(*this);
}
AdvancePrefetchStage();
}
__device__ void PrefetchBatch(int batch_idx, int batch_size)
{
PRAGMA_UNROLL
for (int i = 0; i < batch_size; ++i) {
if (batch_idx * batch_size + i < kIterCount) {
Prefetch(is_valid_s_);
++(*this);
}
}
}
__device__ Iterator& operator++()
{
src_offset_ += kStepC;
dst_offset_ += kStepC;
++iter_c_;
if (iter_c_ < ThreadMap::kIterC) {
return *this;
}
iter_c_ = 0;
src_offset_ += kStepS;
dst_offset_ += kStepS;
offset_s_ += ThreadMap::kDeltaS;
is_valid_s_ = offset_s_ < seq_len_;
return *this;
}
__device__ void AdvancePrefetchStage()
{
src_offset_ += kStepK;
dst_offset_ += kStepK;
offset_s_ += ThreadMap::kS - ThreadMap::kIterS * ThreadMap::kDeltaS;
is_valid_s_ = offset_s_ < seq_len_;
if constexpr (kUseBlockIter) {
if (is_valid_s_) {
block_k_ -= ThreadMap::kS;
if (block_k_ == 0) {
src_ = (const T*)block_iterator_.Next() + layer_offset_ + head_idx_ * block_size_ * ThreadMap::kC;
block_k_ = block_size_;
src_offset_ = init_offset_;
}
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// printf("%d %d %d\n", offset_s_, src_offset_ / ThreadMap::kC, block_k_);
// }
}
// if (init_offset_ / ThreadMap::kC == 0) {
// int k = dst_offset_ / (ThreadMap::kS * ThreadMap::kC);
// int s = dst_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
// int c = dst_offset_ % ThreadMap::kC;
// printf("tid=%d, k=%d, s=%d, c=%d, offset_s=%d, valid_s=%d, init_s=%d\n",
// threadIdx.x,
// k,
// s,
// c,
// offset_s_,
// (int)is_valid_s_,
// init_offset_ / ThreadMap::kC);
// }
// if (threadIdx.x == 0 && blockIdx.x == 0) {
// printf("next stage %d\n", offset_s_);
// }
if (dst_offset_ >= Stages * kSizePerTile) {
dst_offset_ -= Stages * kSizePerTile;
}
// if constexpr (Chained) {
// bool is_last_stage = *signal_iterator_;
// ++signal_iterator_;
// if (is_last_stage) {
// AdvancePrefetchSlice();
// }
// }
}
#if 0
__device__ void AdvancePrefetchSlice()
{
src_ = (const T*)block_iterator_.Next();
src_offset_ = init_offset_;
++iter_b_;
offset_s_ = iter_b_ / 2 * BlockLen + init_offset_ / ThreadMap::kC;
is_valid_s_ = offset_s_ < seq_len_;
}
#endif
static __device__ void CpAsync(T* __restrict__ dst, const T* __restrict__ src, bool mask)
{
const int smem_int_ptr = cast_smem_ptr_to_uint(dst);
constexpr int cp_size = sizeof(AccessType);
#if TURBOMIND_ARCH_SM80
// clang-format off
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
"}\n" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
// clang-format on
#else
assert(TURBOMIND_ARCH_SM80);
#endif
}
static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
{
if (mask) {
Ldg(*(AccessType*)dst, src);
}
}
__device__ void Prefetch(bool mask)
{
if constexpr (TURBOMIND_ARCH_SM80) {
CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
}
else {
Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
}
}
__device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
{
// if (init_offset_ / ThreadMap::kC == 0) {
// int k = smem_read_offset_ / (ThreadMap::kS * ThreadMap::kC);
// int s = smem_read_offset_ % (ThreadMap::kS * ThreadMap::kC) / ThreadMap::kC;
// int c = smem_read_offset_ % ThreadMap::kC;
// printf("tid=%d, k=%d, s=%d, c=%d, init_s=%d\n", threadIdx.x, k, s, c, init_offset_ / ThreadMap::kC);
// }
for (int vi = 0; vi < ThreadMap::kIterC; ++vi) {
// int offset = smem_read_offset_ + vi * ThreadMap::kDeltaC;
// if (offset >= Stages * kSizePerTile || offset % sizeof(AccessType)) {
// int c = offset % ThreadMap::kC;
// int s = offset / ThreadMap::kC;
// printf("%d %d %d\n", c, s, offset);
// }
Lds(frag[vi], smem_ + smem_read_offset_ + vi * ThreadMap::kDeltaC);
}
smem_read_offset_ += ThreadMap::kDeltaS * ThreadMap::kC;
}
__device__ void AdvanceComputeStage()
{
smem_read_offset_ += kStepK;
if (smem_read_offset_ >= Stages * kSizePerTile) {
smem_read_offset_ -= Stages * kSizePerTile;
}
}
};
} // namespace turbomind
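[Editor's note] One detail of the Iterator worth calling out is its shared-memory pipeline: prefetches for Stages tiles are kept in flight, and AdvancePrefetchStage()/AdvanceComputeStage() advance the write and read offsets by exactly one tile per stage and wrap them modulo Stages * kSizePerTile, so the tiles form a ring buffer. A host-side sketch of that offset bookkeeping, with hypothetical sizes:

// Host-side sketch of the ring-buffer offset bookkeeping in the Iterator.
#include <cstdio>

int main() {
    const int kSizePerTile = 64 * 384;  // hypothetical ThreadMap::kC * ThreadMap::kS
    const int Stages = 5;

    int dst_offset = 0;
    for (int stage = 0; stage < 8; ++stage) {
        printf("stage %d writes smem tile %d\n", stage, dst_offset / kSizePerTile);
        dst_offset += kSizePerTile;               // net effect of kStepC/kStepS/kStepK over one stage
        if (dst_offset >= Stages * kSizePerTile)
            dst_offset -= Stages * kSizePerTile;  // same wrap as AdvancePrefetchStage()
    }
}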
This diff is collapsed.
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <cuda_runtime.h>
namespace turbomind {
template<typename T>
void ConvertLinearToBlocks(const T* src,
T** dst_block_ptrs,
const int* dst_cu_block_cnts,
const int* seq_lens,
int dst_offset,
int src_seq_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
template<typename T>
void ConvertBlocksToLinear(const T** src_block_ptrs,
T* dst,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_max_seq_len,
int head_num,
int head_dim,
int batch_size,
cudaStream_t st);
void ConvertKvCacheBlocksToLinear(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
void** dst_k_ptrs,
void** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len, // max{seq_lens}
int head_num,
int head_dim,
int batch_size,
int elem_bits,
cudaStream_t st);
template<typename T>
void ConvertKvCacheBlocksToLinear2(const void** src_k_block_ptrs,
const void** src_v_block_ptrs,
T** dst_k_ptrs,
T** dst_v_ptrs,
const int* src_cu_block_cnts,
const int* seq_lens,
int src_offset,
int src_block_len,
int dst_block_len,
int head_num,
int head_dim,
int batch_size,
int quant_policy,
const float* kv_params,
cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#include "decoder_multihead_attention.h"
#include "kv_cache.h"
#include "test_utils.h"
#include <cmath>
#include <ios>
#include <iostream>
#include <thrust/universal_vector.h>
#include <algorithm>
#include <iomanip>
#include <numeric>
#include <random>
using namespace turbomind;
template<typename T>
T* align(T* ptr, size_t alignment)
{
size_t misalign = (uintptr_t)ptr % alignment;
std::cout << "misalignment: " << misalign << "\n";
if (misalign) {
return (T*)((uint8_t*)ptr + alignment - misalign);
}
return ptr;
}
// [S/S, H, S, D] <-> [S/b, H, b, D]
void TestBlocks(thrust::universal_vector<half>& linear, // linear data
thrust::universal_vector<half>& _blocks, // block data
thrust::universal_vector<half*>& _ptrs, // block ptrs
thrust::universal_vector<int>& _cu_block_cnts, // cumulative block counts
int head_num,
int head_dim,
int block_size,
int batch_size)
{
int seq_len = linear.size() / (head_dim * head_num * batch_size);
int n_blocks = (seq_len + block_size - 1) / block_size;
std::cout << "batch_size = " << batch_size << ", seq_len = " << seq_len << ", block_num = " << n_blocks
<< ", block_size = " << block_size << "\n";
thrust::universal_vector<half> blocks(batch_size * n_blocks * head_num * block_size * head_dim);
thrust::universal_vector<half*> ptrs(batch_size * n_blocks + 1); // +1 padding
std::vector<size_t> idxs(batch_size * n_blocks);
std::iota(idxs.begin(), idxs.end(), 0);
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(idxs.begin(), idxs.end(), g);
for (int i = 0; i < idxs.size(); ++i) {
ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
}
thrust::universal_vector<int> seq_lens(batch_size);
thrust::fill(seq_lens.begin(), seq_lens.end(), seq_len);
std::vector<int> n_blocks_vec(batch_size + 1, n_blocks);
thrust::universal_vector<int> cu_block_cnts(batch_size + 1);
std::exclusive_scan(n_blocks_vec.begin(), n_blocks_vec.end(), cu_block_cnts.begin(), 0);
for (int i = 0; i < 10; ++i) {
ConvertLinearToBlocks((const half*)linear.data().get(),
ptrs.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
seq_len,
block_size,
head_num,
head_dim,
batch_size,
0);
}
thrust::universal_vector<half> _linear(linear.size());
for (int i = 0; i < 10; ++i) {
ConvertBlocksToLinear((const half**)ptrs.data().get(),
_linear.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
block_size,
seq_len,
head_num,
head_dim,
batch_size,
0);
}
cudaDeviceSynchronize();
if (0) {
std::cout << ">>> Compare\n";
Compare(_linear.data().get(), linear.data().get(), head_dim, head_dim, batch_size * head_num * seq_len);
std::cout << "<<< Compare\n";
}
_blocks.swap(blocks);
_ptrs.swap(ptrs);
_cu_block_cnts.swap(cu_block_cnts);
}
int main(int argc, char* argv[])
{
DecoderMultiHeadAttentionParams<half> params{};
constexpr int kHeadNum = 32;
constexpr int kHeadDim = 128;
constexpr int KvHeadNum = 32;
constexpr int kBatchSize = 1;
// constexpr int kContextLen = 7306;
constexpr int kContextLen = 1024;
constexpr int kSequenceLen = kContextLen + 1;
constexpr int kBlockSz = 128;
constexpr int kTestIter = 10;
constexpr int kMaxSplitK = 1;
RNG rng{};
thrust::universal_vector<half> output(kBatchSize * kHeadNum * kHeadDim);
thrust::universal_vector<half> qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
thrust::universal_vector<bool> finished(kBatchSize);
thrust::universal_vector<half> k_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
thrust::universal_vector<half> v_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
thrust::universal_vector<int> sequence_lengths(kBatchSize);
thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
thrust::universal_vector<float> partial_M(kBatchSize * kHeadNum * kMaxSplitK);
thrust::universal_vector<float> partial_L(kBatchSize * kHeadNum * kMaxSplitK);
thrust::universal_vector<float> partial_O(kBatchSize * kHeadNum * kMaxSplitK * kHeadDim);
rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
if (kContextLen) {
rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim,
0,
sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum);
if constexpr (0) {
for (int b = 0; b < kBatchSize; ++b) {
for (int h = 0; h < KvHeadNum; ++h) {
for (int s = 0; s < kSequenceLen; ++s) {
for (int d = 0; d < kHeadDim; ++d) {
std::cout << std::setw(7) << std::setprecision(4) << std::fixed
<< (float)k_cache[b * KvHeadNum * kSequenceLen * kHeadDim
+ h * kSequenceLen * kHeadDim + s * kHeadDim + d]
<< " ";
}
std::cout << "\n";
}
std::cout << "\n";
}
std::cout << "\n";
}
std::exit(0);
}
cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim,
sizeof(half) * kSequenceLen * kHeadDim,
0,
sizeof(half) * kHeadDim,
kBatchSize * KvHeadNum);
}
thrust::universal_vector<half> k_blocks;
thrust::universal_vector<half*> k_ptrs;
thrust::universal_vector<int> cu_block_cnts;
TestBlocks(k_cache, k_blocks, k_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
thrust::universal_vector<half> v_blocks;
thrust::universal_vector<half*> v_ptrs;
TestBlocks(v_cache, v_blocks, v_ptrs, cu_block_cnts, KvHeadNum, kHeadDim, kBlockSz, kBatchSize);
thrust::universal_vector<half> k_cache_ref = k_cache;
thrust::universal_vector<half> v_cache_ref = v_cache;
thrust::universal_vector<half> output_ref = output;
thrust::universal_vector<void*> k_cache_ref_ptrs(kBatchSize);
thrust::universal_vector<void*> v_cache_ref_ptrs(kBatchSize);
cudaDeviceSynchronize();
for (int i = 0; i < kBatchSize; ++i) {
sequence_lengths[i] = kContextLen;
k_cache_ptrs[i] = k_cache.data().get() + i * k_cache.size() / kBatchSize;
v_cache_ptrs[i] = v_cache.data().get() + i * v_cache.size() / kBatchSize;
k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
v_cache_ref_ptrs[i] = v_cache_ref.data().get() + i * v_cache_ref.size() / kBatchSize;
// align(k_cache_ptrs[i], 256);
// align(v_cache_ptrs[i], 256);
}
// getchar();
params.out = output_ref.data().get();
params.q = qkv.data().get();
params.k = params.q + kHeadNum * kHeadDim;
params.v = params.k + KvHeadNum * kHeadDim;
params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
params.batch_size = kBatchSize;
params.max_seq_len = kContextLen + 1;
params.cu_block_cnts = cu_block_cnts.data().get();
printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
params.k_cache_block_ptrs = (void**)k_ptrs.data().get();
params.v_cache_block_ptrs = (void**)v_ptrs.data().get();
params.kv_cache_block_size = kBlockSz;
params.finished = finished.data().get();
params.per_sample_length = sequence_lengths.data().get();
params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
params.layer_offset = 0;
params.num_heads = kHeadNum;
params.num_kv_heads = KvHeadNum;
params.size_per_head = kHeadDim;
params.inv_sqrt_dh = 1.f / std::sqrt((float)params.size_per_head);
params.rotary_embedding_dim = kHeadDim;
params.rotary_embedding_base = 10000.f;
params.partial_L = partial_L.data().get();
params.partial_M = partial_M.data().get();
params.partial_O = partial_O.data().get();
for (int i = 0; i < kTestIter; ++i) {
mmha_ft_reference(params, cudaStream_t{});
}
cudaDeviceSynchronize();
if (auto err = cudaGetLastError(); err != cudaSuccess) {
std::cout << cudaGetErrorString(err) << "\n";
return -1;
}
std::cout << "---------------------------------------------------\n";
params.out = output.data().get();
params.per_sample_k_cache = k_cache_ptrs.data().get();
params.per_sample_v_cache = v_cache_ptrs.data().get();
params.max_split_k = kMaxSplitK;
params.max_seq_len = kContextLen;
params.arch = 80;
std::vector<thrust::universal_vector<half>> outputs;
for (int i = 0; i < std::max(kTestIter, 1); ++i) {
DispatchDecoderMultiheadAttention<half>(params);
if (auto err = cudaGetLastError(); err != cudaSuccess) {
std::cout << cudaGetErrorString(err) << "\n";
return -1;
}
if (1) {
outputs.push_back(output);
}
}
thrust::universal_vector<int> seq_lens(kBatchSize);
for (auto& x : seq_lens) {
x = kContextLen + 1;
}
if (1) {
ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
k_cache.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
kBlockSz,
kSequenceLen,
KvHeadNum,
kHeadDim,
kBatchSize,
0);
ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
v_cache.data().get(),
cu_block_cnts.data().get(),
seq_lens.data().get(),
0,
kBlockSz,
kSequenceLen,
KvHeadNum,
kHeadDim,
kBatchSize,
0);
}
cudaDeviceSynchronize();
if (outputs.size() > 1) {
std::cout << "Evaluating consistency..." << std::endl;
for (size_t i = 1; i < outputs.size(); ++i) {
Compare(outputs[i].data().get(), outputs[0].data().get(), kHeadDim, kHeadDim, kHeadNum);
}
}
std::cout << "---------------------------------------------------\n";
Compare(output.data().get(), output_ref.data().get(), kHeadDim, kHeadDim, kHeadNum, false);
// [H, S, D]
Compare(k_cache.data().get() + kContextLen * kHeadDim,
k_cache_ref.data().get() + kContextLen * kHeadDim,
kSequenceLen * kHeadDim,
kHeadDim,
KvHeadNum);
Compare(v_cache.data().get() + kContextLen * kHeadDim,
v_cache_ref.data().get() + kContextLen * kHeadDim,
kSequenceLen * kHeadDim,
kHeadDim,
KvHeadNum);
return 0;
}
// Copyright (c) OpenMMLab. All rights reserved.
#include "test_utils.h"
#include <cublas_v2.h>
#include <curand.h>
#include <curand_kernel.h>
#include <fstream>
#include <iostream>
#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
#include <cooperative_groups/reduce.h>
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
namespace turbomind {
cublasHandle_t cublas_handle{};
cudaStream_t cublas_stream{};
template<typename T>
void Compare(const T* src, const T* ref, size_t stride, int m, int n, bool show, float rtol, float atol)
{
float asums{};
float rsums{};
int outliers{};
for (int nn = 0; nn < n; ++nn) {
float abs_diff_sum{};
float rel_diff_sum{};
for (int mm = 0; mm < m; ++mm) {
auto x = float(src[nn * stride + mm]);
auto y = float(ref[nn * stride + mm]);
// if (show) {
// std::cout << x << "\t" << y << std::endl;
// }
auto abs_diff = std::abs(x - y);
auto rel_diff = abs_diff / std::abs(y + 1e-6f);
if (abs_diff > atol + rtol * std::abs(y)) {
++outliers;
if (show) {
std::cout << nn << "," << mm << "\t" << x << "\t" << y << std::endl;
}
}
abs_diff_sum += abs_diff;
rel_diff_sum += rel_diff;
}
asums += abs_diff_sum / m;
rsums += rel_diff_sum / m;
}
std::cout << "abs_diff = " << asums / n << " rel_diff = " << rsums / n << " outliers = " << outliers / (float)n
<< std::endl;
}
template void Compare(const half* src, const half* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
template void
Compare(const float* src, const float* ref, size_t stride, int m, int n, bool show, float rtol, float atol);
void LoadBinary(const std::string& path, size_t size, void* dst)
{
std::ifstream ifs(path, std::ios::binary | std::ios::in);
if (!ifs.is_open()) {
std::cerr << "failed to open " << path << "\n";
std::abort();
}
ifs.seekg(0, ifs.end);
auto actual_size_in_bytes = ifs.tellg();
ifs.seekg(0, ifs.beg);
if (size != actual_size_in_bytes) {
std::cerr << "[warning] file " << path << " has " << actual_size_in_bytes << " bytes, while " << size
<< " bytes is requested\n";
}
ifs.read((char*)dst, size);
std::cerr << "[info] " << path << " " << size << "\n";
}
namespace cg = cooperative_groups;
__global__ void curand_init(curandState* state)
{
auto tid = cg::this_grid().thread_rank();
curand_init(0xe4c45822e90461ddULL, tid, 0, state + tid);
}
template<typename T>
__global__ void curand_uniform(curandState* state, size_t count, T* result, float scale, float shift)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
float tmp = curand_uniform(state + grid.thread_rank());
result[i] = T(scale * tmp + shift);
}
}
template<typename T>
__global__ void curand_normal(curandState* state, size_t count, T* result, float scale, float shift)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
float tmp = curand_normal(state + grid.thread_rank());
result[i] = T(scale * tmp + shift);
}
}
__global__ void curand_bytes(curandState* state, size_t count, uint* result)
{
auto grid = cg::this_grid();
for (auto i = grid.thread_rank(); i < count; i += grid.size()) {
result[i] = curand(state + grid.thread_rank());
}
}
struct RNG::Impl {
curandState* states{};
Impl()
{
cudaMalloc(&states, sizeof(curandState) * 64 * 64);
curand_init<<<64, 64>>>(states);
}
~Impl()
{
cudaFree(states);
}
void GenerateUInt(uint* out, size_t count)
{
curand_bytes<<<64, 64>>>(states, count, out);
}
template<typename T>
void GenerateUniform(T* out, size_t count, float scale, float shift)
{
curand_uniform<<<64, 64>>>(states, count, out, scale, shift);
}
template<typename T>
void GenerateNormal(T* out, size_t count, float scale, float shift)
{
curand_normal<<<64, 64>>>(states, count, out, scale, shift);
}
};
RNG::RNG(): impl_(std::make_unique<Impl>()) {}
RNG::~RNG() = default;
void RNG::GenerateUInt(uint* out, size_t count)
{
impl_->GenerateUInt(out, count);
}
template<typename T>
void RNG::GenerateUniform(T* out, size_t count, float scale, float shift)
{
std::cout << count << std::endl;
impl_->GenerateUniform(out, count, scale, shift);
}
template<typename T>
void RNG::GenerateNormal(T* out, size_t count, float scale, float shift)
{
impl_->GenerateNormal(out, count, scale, shift);
}
template void RNG::GenerateUniform(half* out, size_t count, float scale, float shift);
template void RNG::GenerateUniform(float* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(half* out, size_t count, float scale, float shift);
template void RNG::GenerateNormal(float* out, size_t count, float scale, float shift);
template<typename T>
struct SATypeConverter {
using Type = T;
};
template<>
struct SATypeConverter<half> {
using Type = uint16_t;
};
template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t st)
{
using DataType = typename SATypeConverter<T>::Type;
// Prepare the parameters.
Masked_multihead_attention_params<DataType> params{};
params.q_bias = reinterpret_cast<const DataType*>(p.q_bias);
params.k_bias = reinterpret_cast<const DataType*>(p.k_bias);
params.v_bias = reinterpret_cast<const DataType*>(p.v_bias);
// Set the output buffer.
params.out = reinterpret_cast<DataType*>(p.out);
// Set the input buffers.
// [B, nH + kvH, D]
params.q = reinterpret_cast<const DataType*>(p.q);
params.k = reinterpret_cast<const DataType*>(p.k);
params.v = reinterpret_cast<const DataType*>(p.v);
params.stride = p.stride;
params.finished = (bool*)p.finished;
params.k_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_k_cache);
params.v_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_v_cache);
params.kv_cache_per_sample_offset = p.layer_offset;
params.batch_size = p.batch_size;
params.beam_width = 1;
params.memory_max_len = p.max_seq_len;
params.prefix_prompt_lengths = 0;
params.max_prefix_prompt_length = 0;
params.length_per_sample = p.per_sample_length; // max_input_length + current output length
for (int i = 0; i < p.batch_size; ++i) {
params.timestep = std::max(p.per_sample_length[i], params.timestep);
}
std::cout << "timestep = " << params.timestep << "\n";
params.num_heads = p.num_heads;
params.num_kv_heads = p.num_kv_heads;
params.hidden_size_per_head = p.size_per_head;
params.rotary_embedding_dim = p.rotary_embedding_dim;
params.max_position_embeddings = p.max_position_embeddings;
params.use_dynamic_ntk = false;
params.use_logn_attn = p.use_logn_attn;
// Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * 1.f);
params.int8_mode = 0;
masked_multihead_attention(params, st);
}
template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "decoder_multihead_attention.h"
#include "src/turbomind/macro.h"
#include <cuda_fp16.h>
#include <memory>
namespace turbomind {
template<typename T>
void Compare(
const T* src, const T* ref, size_t stride, int m, int n, bool show = false, float rtol = 1e-2, float atol = 1e-4);
void LoadBinary(const std::string& path, size_t size, void* dst);
class RNG {
public:
RNG();
~RNG();
void GenerateUInt(uint* out, size_t count);
template<typename T>
void GenerateUniform(T* out, size_t count, float scale = 1.f, float shift = 0.f);
template<typename T>
void GenerateNormal(T* out, size_t count, float scale = 1.f, float shift = 0.f);
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
template<typename T>
void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params, cudaStream_t st);
} // namespace turbomind
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include "../gemm_s_f16/common.h"
namespace turbomind {
template<int C, int S, int AccessC, int WarpCount>
struct ThreadMapQ {
static constexpr int kWarpCount = WarpCount;
static constexpr int kAccessC = AccessC;
static constexpr int kWarpThreadC = C / kAccessC;
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
static_assert(kWarpThreadC <= WARP_SIZE);
static constexpr int kWarpAccessC = kWarpThreadC * kAccessC; // C
static constexpr int kWarpAccessS = kWarpThreadS;
static constexpr int kWarpIterC = C / kWarpAccessC; // 1
static constexpr int kWarpIterS = S / kWarpAccessS;
static constexpr int kWarpC = 1;
static constexpr int kWarpS = kWarpCount;
static constexpr int kIterC = kWarpIterC / kWarpC; // 1
static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
static constexpr int kFootprintC = kWarpAccessC * kIterC; // C
static constexpr int kFootprintS = kWarpAccessS * kIterS;
static constexpr int kDeltaC = kWarpAccessC;
static constexpr int kDeltaS = kWarpAccessS;
__device__ static int2 get_offset(int warp_id, int lane_id)
{
int warp_offset_c = warp_id % kWarpC;
int warp_offset_s = warp_id / kWarpC;
int warp_thread_offset_c = lane_id % kWarpThreadC;
int warp_thread_offset_s = lane_id / kWarpThreadC;
int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
return {cta_thread_offset_c, cta_thread_offset_s};
}
};
template<int C, int S, int AccessC, int WarpThreadC, int WarpCount>
struct ThreadMapKv {
static constexpr int kC = C;
static constexpr int kS = S;
static constexpr int kWarpCount = WarpCount;
static constexpr int kAccessC = AccessC;
static constexpr int kWarpThreadC = WarpThreadC;
static constexpr int kWarpThreadS = WARP_SIZE / kWarpThreadC;
static_assert(kWarpThreadC <= WARP_SIZE);
static constexpr int kWarpAccessC = kWarpThreadC * kAccessC;
static constexpr int kWarpAccessS = kWarpThreadS;
static constexpr int kWarpIterC = C / kWarpAccessC;
static constexpr int kWarpIterS = S / kWarpAccessS;
static constexpr int kWarpC = 1;
static constexpr int kWarpS = kWarpCount;
static constexpr int kIterC = kWarpIterC / kWarpC;
static constexpr int kIterS = std::max(kWarpIterS / kWarpS, 1);
static constexpr int kFootprintC = kWarpAccessC * kIterC;
static constexpr int kFootprintS = kWarpAccessS * kIterS;
static constexpr int kDeltaC = kWarpAccessC;
static constexpr int kDeltaS = kWarpAccessS;
__device__ static int2 get_offset(int warp_id, int lane_id)
{
int warp_offset_c = warp_id % kWarpC;
int warp_offset_s = warp_id / kWarpC;
int warp_thread_offset_c = lane_id % kWarpThreadC;
int warp_thread_offset_s = lane_id / kWarpThreadC;
int cta_thread_offset_c = kFootprintC * warp_offset_c + warp_thread_offset_c * kAccessC;
int cta_thread_offset_s = kFootprintS * warp_offset_s + warp_thread_offset_s;
return {cta_thread_offset_c, cta_thread_offset_s};
}
};
} // namespace turbomind
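[Editor's note] ThreadMapQ/ThreadMapKv encode, entirely at compile time, how the CTA's warps and lanes tile a (C x S) slab: each lane loads kAccessC channels at a time, kWarpThreadC lanes cover the channel dimension, the remaining lanes of the warp stack along S, and leftover S coverage becomes per-thread iterations. The constexpr host sketch below evaluates those derived quantities for one hypothetical configuration (not necessarily the one the kernel instantiates):

// Host-side constexpr sketch of the quantities ThreadMapKv derives.
#include <algorithm>
#include <cstdio>

int main() {
    constexpr int WARP_SIZE = 32;
    constexpr int C = 128, S = 64;            // channels x sequence positions per tile (hypothetical)
    constexpr int AccessC = 8, WarpThreadC = 16, WarpCount = 4;

    constexpr int kWarpThreadS = WARP_SIZE / WarpThreadC;               // 2 lanes stacked along S
    constexpr int kWarpAccessC = WarpThreadC * AccessC;                 // 128 channels per warp access
    constexpr int kWarpAccessS = kWarpThreadS;                          // 2 positions per warp access
    constexpr int kWarpIterC   = C / kWarpAccessC;                      // 1
    constexpr int kWarpIterS   = S / kWarpAccessS;                      // 32
    constexpr int kIterC       = kWarpIterC / 1;                        // kWarpC == 1
    constexpr int kIterS       = std::max(kWarpIterS / WarpCount, 1);   // 8 iterations per warp

    printf("per-thread iterations: C=%d, S=%d (footprint %dx%d)\n",
           kIterC, kIterS, kWarpAccessC * kIterC, kWarpAccessS * kIterS);
}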
...@@ -3,6 +3,7 @@
#pragma once
#include "common.h"
#include <cstddef>
#include <cstdint>
namespace turbomind {
...@@ -236,7 +237,13 @@ struct IteratorA {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_cg_A(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
#else
if (mask) {
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
}
#endif
}
};
...@@ -417,7 +424,13 @@ struct IteratorQ {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_ca(smem_int_ptr_ + dst_offset_, (const AccessType*)src_ + src_offset_, mask);
#else
if (mask) {
*(AccessType*)((uint8_t*)smem_ + dst_offset_) = __ldg((const AccessType*)src_ + src_offset_);
}
#endif
}
};
...@@ -613,8 +626,14 @@ struct IteratorB {
__device__ void prefetch(bool mask)
{
#if TURBOMIND_ARCH_SM80
cp_async_cg_B(
smem_int_ptr_ + tmp_dst_offset_, (const AccessType*)(src_ + tmp_src_offset_), is_valid_n_ && mask);
#else
if (is_valid_n_ && mask) {
*(AccessType*)((uint8_t*)smem_ + tmp_dst_offset_) = __ldg((const AccessType*)(src_ + tmp_src_offset_));
}
#endif
}
};
......