Commit 9f9f3796 authored by zhangshao
Browse files

恢复对bf16的支持

parent f38bd872
......@@ -68,40 +68,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// No-op overload of paged_attention_kernel.
//
// The trailing unnamed enable_if_t template parameter makes this overload
// viable only when scalar_t is NOT uint16_t. Its body is empty, so for those
// scalar types a call compiles but does nothing: no output buffer is written.
// NOTE(review): presumably uint16_t is the fp16 storage type and this stub
// exists to switch the kernel off for other dtypes (e.g. bf16) — confirm
// against the callers.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int PARTITION_SIZE = 0,
std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> =
0> // PARTITION_SIZE == 0 means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {}  // Empty body: performs no work.
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int PARTITION_SIZE = 0,
std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> =
0> // Zero means no partitioning.
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
......@@ -133,6 +100,7 @@ __device__ void paged_attention_kernel(
// No work to do. Terminate the thread block.
return;
}
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
......@@ -723,6 +691,7 @@ __device__ void paged_attention_kernel(
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
......
......@@ -80,8 +80,8 @@ inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
// template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0;
......@@ -114,9 +114,9 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
......@@ -138,7 +138,7 @@ template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
return qk_dot_<THREAD_GROUP_SIZE,Vec,N,T>(q, k);
}
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment