"vllm/vscode:/vscode.git/clone" did not exist on "07ab160741a486bbef23efbf26aaf2ea8a785ae1"
Commit 9f9f3796 authored by zhangshao
Browse files

恢复对bf16的支持

parent f38bd872
...@@ -68,40 +68,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -68,40 +68,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false, bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int PARTITION_SIZE = 0, int PARTITION_SIZE = 0> // Zero means no partitioning.
std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> =
0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES = 1, bool odd_nheads = false,
int PARTITION_SIZE = 0,
std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> =
0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
...@@ -133,6 +100,7 @@ __device__ void paged_attention_kernel( ...@@ -133,6 +100,7 @@ __device__ void paged_attention_kernel(
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
...@@ -723,6 +691,7 @@ __device__ void paged_attention_kernel( ...@@ -723,6 +691,7 @@ __device__ void paged_attention_kernel(
} }
} }
} }
}
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
......
...@@ -80,8 +80,8 @@ inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c ...@@ -80,8 +80,8 @@ inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c
} }
// Q*K^T operation. fp16 // Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0> template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N> // template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0; float qk =0;
...@@ -114,9 +114,9 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -114,9 +114,9 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
} }
// Q*K^T operation. //bf16 // Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0> template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N> // template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]); A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
...@@ -138,7 +138,7 @@ template <typename T, int THREAD_GROUP_SIZE> ...@@ -138,7 +138,7 @@ template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template <typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE,Vec,N,T>(q, k);
} }
// template <typename Vec, int N> // template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) { // static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment