Unverified commit 6e58fced authored by tpoisonooo, committed by GitHub

fix(kernel): speed degrade (#41)

* feat(template): remote diff

* feat(cmake): use c++17
parent 8aa6eb10
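
The kernel change below addresses the speed degradation by turning the KV-cache quantization mode into a compile-time template parameter (QUANT_POLICY): the launcher inspects params.int8_mode once and instantiates the matching kernel, so the non-quantized path no longer evaluates int8 branches inside its inner loops. A minimal, self-contained sketch of that pattern (host-side C++, illustrative names only, not the repo's API):

    // The policy is a non-type template parameter, so the unused branch is
    // removed at compile time inside the hot loop.
    template<int QUANT_POLICY>
    void hot_loop(const float* in, float* out, int n)
    {
        for (int i = 0; i < n; ++i) {
            if (QUANT_POLICY == 0) {  // plain floating-point path
                out[i] = in[i];
            }
            else if (QUANT_POLICY == 4) {  // int8 round-trip path (inputs assumed in [-1, 1])
                const int q = static_cast<int>(in[i] * 127.0f);
                out[i]      = static_cast<float>(q) / 127.0f;
            }
        }
    }

    void run(const float* in, float* out, int n, int int8_mode)
    {
        // The runtime flag is checked once, outside the loop, as in mmha_launch_kernel.
        if (int8_mode == 4) { hot_loop<4>(in, out, n); }
        else                { hot_loop<0>(in, out, n); }
    }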
@@ -84,7 +84,7 @@ if(USE_TRITONSERVER_DATATYPE)
     add_definitions("-DUSE_TRITONSERVER_DATATYPE")
 endif()
 
-set(CXX_STD "14" CACHE STRING "C++ standard")
+set(CXX_STD "17" CACHE STRING "C++ standard")
 
 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
...
@@ -26,10 +26,10 @@
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream) \
     size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
     dim3 grid(params.num_heads, params.batch_size); \
-    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS> \
+    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY> \
         <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -40,18 +40,30 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st
 {
     constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
     // constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
-    int tlength = params.timestep;
+    const int tlength = params.timestep;
     FT_CHECK(params.cache_indir == nullptr);
-    if (tlength < 32) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
-    }
-    else if (tlength < 2048) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
-    }
-    else {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
+    if (params.int8_mode == 4) {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 4, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, 4, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4, stream);
+        }
+    } else {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 0, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, 0, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 0, stream);
+        }
     }
 }
...
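
A simplified view of what the updated macro and launcher do together (toy kernel and macro, not the real MMHA signature): the extra QUANT_POLICY macro argument is forwarded straight into the kernel's template argument list, so each branch of mmha_launch_kernel instantiates a distinct specialization.

    #include <cuda_runtime.h>

    template<typename T, int QUANT_POLICY>
    __global__ void toy_kernel(T* out)
    {
        // QUANT_POLICY is a compile-time constant; only one branch survives per instantiation.
        if (QUANT_POLICY == 4) {
            out[threadIdx.x] = T(1) / T(127);  // stands in for the int8 KV-cache path
        }
        else {
            out[threadIdx.x] = T(1);           // stands in for the plain KV-cache path
        }
    }

    // Toy counterpart of MMHA_LAUNCH_KERNEL: the macro parameter becomes a template
    // argument. Like the real macro, it refers to a variable ('out') from the caller's scope.
    #define TOY_LAUNCH_KERNEL(T, THDS_PER_BLOCK, QUANT_POLICY, stream) \
        toy_kernel<T, QUANT_POLICY><<<1, THDS_PER_BLOCK, 0, stream>>>(out)

    void launch_toy(float* out, int int8_mode, cudaStream_t stream)
    {
        if (int8_mode == 4) {
            TOY_LAUNCH_KERNEL(float, 128, 4, stream);
        }
        else {
            TOY_LAUNCH_KERNEL(float, 128, 0, stream);
        }
    }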
@@ -17,9 +17,8 @@
 #include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
 #include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
-#include "src/turbomind/models/llama/llama_utils.h"
-#include "src/turbomind/utils/cuda_bf16_wrapper.h"
-#include "src/turbomind/utils/cuda_fp8_utils.h"
+// #include "src/turbomind/utils/cuda_bf16_wrapper.h"
+// #include "src/turbomind/utils/cuda_fp8_utils.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
 #include <assert.h>
 #include <float.h>
@@ -1272,7 +1271,8 @@ template<typename T, // The type of the inputs. Supported types: float and half
          int THREADS_PER_KEY,    // The number of threads per key.
          int THREADS_PER_VALUE,  // The number of threads per value.
          int THREADS_PER_BLOCK,  // The number of threads in a threadblock.
-         bool HAS_BEAMS>
+         bool HAS_BEAMS,
+         int QUANT_POLICY>  // quantization method
 __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params)
 {
@@ -1462,16 +1462,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
                          + tlength_circ * QK_ELTS_IN_16B + ci;
 
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
+            } else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
                 Packed_Int8_t k_int8 = quant(k, k_scale);
                 int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
                 *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
             }
-            else {
-                *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-            }
         }
         else {
             int offset;
@@ -1484,17 +1483,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                          + co * QK_ELTS_IN_16B + ci;
             }
 
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
+                    vec_conversion<Qk_vec_m, Qk_vec_k>(k);
+            } else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
                 Packed_Int8_t k_int8 = quant(k, k_scale);
                 int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
                 *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
             }
-            else {
-                *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
-                    vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-            }
         }
     }
 }
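
In the QUANT_POLICY == 4 branches above, quant(k, k_scale) packs the key vector to int8 before it is written to the cache. A hypothetical scalar analogue, assuming a symmetric scale where x ≈ q * scale (the repo's packed quant() may differ in rounding and vector layout); std::clamp is C++17, in line with the CXX_STD bump above:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar sketch of symmetric int8 quantization with a per-tensor scale.
    inline int8_t quant_scalar(float x, float scale)
    {
        const float q = std::nearbyint(x / scale);  // nearest quantization step
        return static_cast<int8_t>(std::clamp(q, -127.0f, 127.0f));
    }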
@@ -1575,7 +1573,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     T* k_cache_batch = nullptr;
     int8_t* k_cache_batch_int8 = nullptr;
 
-    if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+    if (not QUANT_POLICY) {
+        k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+                                                     + hi * params.memory_max_len * Dh + ki) :
+                                                    &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+        // Base pointer for the beam's batch, before offsetting with indirection buffer
+        // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
+    } else if (QUANT_POLICY == 4) {
         // convert k_cache_per_sample to int8
         if (params.k_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
@@ -1586,14 +1590,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
         }
     }
-    else {
-        T* k_cache = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
-                                                  + hi * params.memory_max_len * Dh + ki) :
-                                                 &params.k_cache[bhi * params.memory_max_len * Dh + ki];
-        // Base pointer for the beam's batch, before offsetting with indirection buffer
-        // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
-        k_cache_batch = k_cache;
-    }
 
     // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
     // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
@@ -1629,7 +1625,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                 beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
             }
 
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                k[ii] = vec_conversion<K_vec_k, K_vec_m>(
+                    (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
+            } else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
@@ -1639,10 +1638,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                 k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
             }
-            else {
-                k[ii] = vec_conversion<K_vec_k, K_vec_m>(
-                    (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
-            }
         }
     }
 }
@@ -1763,7 +1758,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     int8_t* v_cache_int8 = nullptr;
     int8_t* v_cache_batch_int8 = nullptr;
 
-    if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+    if (not QUANT_POLICY) {
+        v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+                                               + hi * params.memory_max_len * Dh + vi) :
+                                              &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+        // Base pointer for the beam's batch, before offsetting with indirection buffer
+        // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
+        v_cache_batch = v_cache;
+    } else if (QUANT_POLICY == 4) {
         if (params.v_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
             v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
@@ -1775,15 +1778,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             v_cache_batch_int8 = v_cache_int8;
         }
     }
-    else {
-        v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
-                                               + hi * params.memory_max_len * Dh + vi) :
-                                              &params.v_cache[bhi * params.memory_max_len * Dh + vi];
-        // Base pointer for the beam's batch, before offsetting with indirection buffer
-        // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
-        v_cache_batch = v_cache;
-    }
 
     // The number of values processed per iteration of the loop.
     constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
@@ -1834,17 +1828,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             // Load the values from the cache.
             V_vec_k v;
 
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                v = vec_conversion<V_vec_k, V_vec_m>(
+                    *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
+            } else if (QUANT_POLICY == 4) {
                 Packed_Int8_t v_vec_m_int8 =
                     *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
                 Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
                 v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
             }
-            else {
-                v = vec_conversion<V_vec_k, V_vec_m>(
-                    *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
-            }
 
             // Load the logits from shared memory.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -1881,18 +1874,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
             // Load the values from the cache.
             V_vec_k v;
-
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                v = vec_conversion<V_vec_k, V_vec_m>(
+                    *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
+            } else if (QUANT_POLICY == 4) {
                 Packed_Int8_t v_vec_m_int8 =
                     *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
                 Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
                 v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
             }
-            else {
-                v = vec_conversion<V_vec_k, V_vec_m>(
-                    *reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
-            }
 
             // Load the logits from shared memory.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -1938,14 +1929,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         // Store the values with bias back to global memory in the cache for V.
         //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
 
-        if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+        if (not QUANT_POLICY) {
+            *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
+        } else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
             Packed_Int8_t v_int8 = quant(v, v_scale);
             *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
         }
-        else {
-            *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
-        }
     }
 
     // Initialize the output value with the current timestep.
...
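
The read side mirrors the write side: the QUANT_POLICY == 4 branches load int8 from the cache and dequant(..., v_scale) back to floating point before the attention reduction. A scalar round-trip sketch under the same assumed convention (the real quant()/dequant() operate on packed vectors):

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    inline float dequant_scalar(int8_t q, float scale)
    {
        return static_cast<float>(q) * scale;  // assumed inverse of the quantization above
    }

    int main()
    {
        const float  scale = 0.05f;  // e.g. amax / 127 from KV-cache calibration
        const float  x     = 1.234f;
        const int8_t q     = static_cast<int8_t>(std::nearbyint(x / scale));
        const float  y     = dequant_scalar(q, scale);
        assert(std::fabs(x - y) <= 0.5f * scale);  // error bounded by half a quantization step
        return 0;
    }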