Commit 1be9a629 authored by zhangshao's avatar zhangshao
Browse files

pa优化,编译选项优化

parent d4c0015a
...@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX) ...@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX)
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda") option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
set(CMAKE_BUILD_TYPE "Release")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
add_compile_options(-w)
# #
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
...@@ -120,10 +122,11 @@ endif() ...@@ -120,10 +122,11 @@ endif()
# the supported versions for the current language. # the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`. # The final set of arches is stored in `VLLM_GPU_ARCHES`.
# #
override_gpu_arches(VLLM_GPU_ARCHES #override_gpu_arches(VLLM_GPU_ARCHES
${VLLM_GPU_LANG} # ${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") # "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
set(VLLM_GPU_ARCHES "gfx928")
message(STATUS "${VLLM_GPU_ARCHES}")
# #
# Query torch for additional GPU compilation flags for the given # Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`. # `VLLM_GPU_LANG`.
......
...@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
list(REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
# "-DENABLE_FP8" # "-DENABLE_FP8"
...@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc" "-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024") "--gpu-max-threads-per-block=1024")
message(STATUS "${GPU_FLAGS}")
endif() endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction() endfunction()
......
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
...@@ -39,6 +20,8 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -39,6 +20,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...@@ -86,7 +69,9 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -86,7 +69,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0,std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
...@@ -98,7 +83,39 @@ __device__ void paged_attention_kernel( ...@@ -98,7 +83,39 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0,std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
...@@ -108,9 +125,9 @@ __device__ void paged_attention_kernel( ...@@ -108,9 +125,9 @@ __device__ void paged_attention_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.y;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
...@@ -121,7 +138,7 @@ __device__ void paged_attention_kernel( ...@@ -121,7 +138,7 @@ __device__ void paged_attention_kernel(
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
...@@ -144,22 +161,38 @@ __device__ void paged_attention_kernel( ...@@ -144,22 +161,38 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; // const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x; int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
const int num_heads = gridDim.x; int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; // const float alibi_slope =
const float alibi_slope = // alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread // The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a // group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 / // thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
...@@ -176,61 +209,89 @@ __device__ void paged_attention_kernel( ...@@ -176,61 +209,89 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the query, and the second thread // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; // const scalar_t* q_ptr = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; __shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
i += NUM_THREAD_GROUPS) { // #pragma unroll
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; // for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
q_vecs[thread_group_offset][i] = // i += NUM_THREAD_GROUPS) {
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); // const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
} // q_vecs[thread_group_offset][i] =
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a // *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// memory wall right before we use q_vecs // }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS]; __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE // x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time. // Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX; float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
...@@ -254,8 +315,8 @@ __device__ void paged_attention_kernel( ...@@ -254,8 +315,8 @@ __device__ void paged_attention_kernel(
continue; continue;
} }
} }
const int64_t physical_block_number = const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
static_cast<int64_t>(block_table[block_idx]); const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
...@@ -263,99 +324,104 @@ __device__ void paged_attention_kernel( ...@@ -263,99 +324,104 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale); k_vec_quant, kv_scale);
}
} }
} }
__builtin_amdgcn_sched_barrier(0);
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot( float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
} }
} }
} }
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
#pragma unroll for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { const int head_idx = head_idx_soffset + reuse_kv_idx;
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); if(!odd_nheads || head_idx < q_boundary) {
} #pragma unroll
if (lane == 0) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
red_smem[warp_idx] = qk_max; qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
} }
__syncthreads(); if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
// TODO(woosuk): Refactor this part. }
// Get the max qk value for the sequence. __syncthreads();
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll // TODO(woosuk): Refactor this part.
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { // Get the max qk value for the sequence.
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
} #pragma unroll
// Broadcast the max qk value to all threads. for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = VLLM_SHFL_SYNC(qk_max, 0); qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
// Get the sum of the exp values. for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float exp_sum = 0.f; float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[(reuse_kv_idx * partition_size) + i] = val;
float val = __expf(logits[i] - qk_max); exp_sum[reuse_kv_idx] += val;
logits[i] = val; }
exp_sum += val; exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum; logits[(reuse_kv_idx * partition_size) + i] *= inv_sum;
} }
__syncthreads(); __syncthreads();
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum[reuse_kv_idx];
}
}
} }
// Each thread will fetch 16 bytes from the value cache at a time. // Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
...@@ -369,44 +435,74 @@ __device__ void paged_attention_kernel( ...@@ -369,44 +435,74 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
}
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
...@@ -428,20 +524,41 @@ __device__ void paged_attention_kernel( ...@@ -428,20 +524,41 @@ __device__ void paged_attention_kernel(
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
accs[i] += dot(logits_vec, v_vec); // if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i]; float acc = accs[reuse_kv_idx][i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[i] = acc; accs[reuse_kv_idx][i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for // NOTE(woosuk): A barrier is required because the shared memory space for
...@@ -455,12 +572,12 @@ __device__ void paged_attention_kernel( ...@@ -455,12 +572,12 @@ __device__ void paged_attention_kernel(
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i]; dst[row_idx] = accs[reuse_kv_idx][i];
} }
} }
} }
...@@ -468,12 +585,12 @@ __device__ void paged_attention_kernel( ...@@ -468,12 +585,12 @@ __device__ void paged_attention_kernel(
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx]; accs[reuse_kv_idx][i] += src[row_idx];
} }
} }
} }
...@@ -489,23 +606,29 @@ __device__ void paged_attention_kernel( ...@@ -489,23 +606,29 @@ __device__ void paged_attention_kernel(
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]); from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
} }
} }
} }
}
}
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE> int REUSE_KV_TIMES = 1,
__global__ void paged_attention_v1_kernel( bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -516,22 +639,24 @@ __global__ void paged_attention_v1_kernel( ...@@ -516,22 +639,24 @@ __global__ void paged_attention_v1_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE> int REUSE_KV_TIMES,
__global__ void paged_attention_v2_kernel( int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -542,7 +667,8 @@ __global__ void paged_attention_v2_kernel( ...@@ -542,7 +667,8 @@ __global__ void paged_attention_v2_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
...@@ -552,19 +678,19 @@ __global__ void paged_attention_v2_kernel( ...@@ -552,19 +678,19 @@ __global__ void paged_attention_v2_kernel(
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank, kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -674,22 +800,32 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -674,22 +800,32 @@ __global__ void paged_attention_v2_reduce_kernel(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \ BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \ KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \ shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \ NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
<<<grid, block, shared_mem_size, stream>>>( \ , dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \ kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
...@@ -705,7 +841,10 @@ void paged_attention_v1_launcher( ...@@ -705,7 +841,10 @@ void paged_attention_v1_launcher(
int q_stride = query.stride(0); int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
...@@ -722,48 +861,31 @@ void paged_attention_v1_launcher( ...@@ -722,48 +861,31 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int padded_max_seq_len = REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
int logits_size = padded_max_seq_len * sizeof(float); HEADSIZE_SWITCH(head_size, [&] {
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); NUM_THREADS_SWITCH(num_threads, [&] {
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len OPT_SWITCH(num_heads == num_kv_heads, [&] {
// Keep that in sync with the logic here! constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int shared_mem_size = std::max(logits_size, outputs_size); int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid(num_heads, num_seqs, 1); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
dim3 block(NUM_THREADS); // Keep that in sync with the logic here!
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); int shared_mem_size = ::max(logits_size, outputs_size);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
switch (head_size) { // int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// NOTE(woosuk): To reduce the compilation time, we only compile for the // std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// head sizes that we use in the model. However, we can easily extend this dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
// to support any head size which is a multiple of 16. dim3 block(NUM_THREADS);
case 64: const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
LAUNCH_PAGED_ATTENTION_V1(64); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
break; LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
case 80: });
LAUNCH_PAGED_ATTENTION_V1(80); });
break; });
case 96: });
LAUNCH_PAGED_ATTENTION_V1(96); });
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...@@ -788,20 +910,25 @@ void paged_attention_v1_launcher( ...@@ -788,20 +910,25 @@ void paged_attention_v1_launcher(
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \ case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
...@@ -826,19 +953,19 @@ void paged_attention_v1( ...@@ -826,19 +953,19 @@ void paged_attention_v1(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \ REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
<<<grid, block, shared_mem_size, stream>>>( \ , dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \ PARTITION_SIZE>) \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \ , dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
...@@ -883,48 +1010,33 @@ void paged_attention_v2_launcher( ...@@ -883,48 +1010,33 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float); REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
// For paged attention v2 kernel. OPT_SWITCH(num_heads == num_kv_heads, [&] {
dim3 grid(num_heads, num_seqs, max_num_partitions); int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size); int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs); // For paged attention v2 kernel.
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); // dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 block(NUM_THREADS); dim3 grid;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); grid.y = max_num_partitions;
switch (head_size) { grid.z = num_seqs;
// NOTE(woosuk): To reduce the compilation time, we only compile for the // int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
// head sizes that we use in the model. However, we can easily extend this int shared_mem_size = ::max(logits_size, outputs_size);
// to support any head size which is a multiple of 16. // For paged attention v2 reduce kernel.
case 64: dim3 reduce_grid(num_heads, num_seqs);
LAUNCH_PAGED_ATTENTION_V2(64); int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
break; dim3 block(NUM_THREADS);
case 80: const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
LAUNCH_PAGED_ATTENTION_V2(80); const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
break; LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
case 96: });
LAUNCH_PAGED_ATTENTION_V2(96); });
break; });
case 112: });
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...@@ -949,20 +1061,25 @@ void paged_attention_v2_launcher( ...@@ -949,20 +1061,25 @@ void paged_attention_v2_launcher(
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \ case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
...@@ -992,4 +1109,4 @@ void paged_attention_v2( ...@@ -992,4 +1109,4 @@ void paged_attention_v2(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -26,19 +26,106 @@ ...@@ -26,19 +26,106 @@
namespace vllm { namespace vllm {
// Q*K^T operation. inline __device__ void v_dot2_f32_f16(float& a, const uint32_t & b,const uint32_t & c) {
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;": "=v"(a): "v"(b), "v"(c), "0"(a));
}
inline __device__ void v_pk_fma_f16(uint32_t& a, const uint32_t & b,const uint32_t & c){
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;": "=v"(a) : "v"(b), "v"(c), "v"(a));
}
inline __device__ void ds_read_b128(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1;": "=v" (a): "v" (offset));
}
inline __device__ void ds_read_b128_sync(uint4& a, uint32_t offset){
asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);": "=v" (a): "v" (offset));
}
inline __device__ void lgkmcnt0(){
asm volatile("s_waitcnt lgkmcnt(0);");
}
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void *__ptr) {
return (size_t)(void __attribute__((address_space(3))) *)__ptr;
}
inline __device__ void v_dot2_f32_f16(float& a,const uint2 & b,const uint2 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
}
inline __device__ void v_dot2_f32_f16(float& a,const uint4 & b,const uint4 & c) {
v_dot2_f32_f16(a, b.x, c.x);
v_dot2_f32_f16(a, b.y, c.y);
v_dot2_f32_f16(a, b.z, c.z);
v_dot2_f32_f16(a, b.w, c.w);
}
inline __device__ float add_half2(uint32_t a){
union {
uint32_t u32;
half u16[2];
} tmp;
tmp.u32=a;
return static_cast<float>(tmp.u16[0]+tmp.u16[1]);
}
inline __device__ void v_pk_fma_f16x8(float& a,const uint4 & b,const uint4 & c) {
uint32_t tmp = mul<uint32_t, uint32_t, uint32_t>(b.x,c.x);
v_pk_fma_f16(tmp,b.y,c.y);
v_pk_fma_f16(tmp,b.z,c.z);
v_pk_fma_f16(tmp,b.w,c.w);
a+=add_half2(tmp);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk =0;
// uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
// const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
// // Compute the parallel products for Q*K^T (treat vector lanes separately).
// constexpr int loop=N*sizeof(Vec)/16/2;
// uint4 qt[2];
// #pragma unroll
// for (int ii = 0; ii < loop; ++ii) {
// ds_read_b128(qt[0],offset+16*ii*2);
// ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
// v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
// // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
// lgkmcnt0();
// v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
// // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
// }
#pragma unroll
for (int ii = 0; ii < N; ++ii) {
v_dot2_f32_f16(qk,q[ii],k[ii]);
}
// Finalize the reduction across lanes.
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]); A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll #pragma unroll
for (int ii = 1; ii < N; ++ii) { for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec); qk_vec = fma(q[ii], k[ii], qk_vec);
} }
// Finalize the reduction across lanes.
float qk = sum(qk_vec); float qk = sum(qk_vec);
// Finalize the reduction across lanes.
#pragma unroll #pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask); qk += VLLM_SHFL_XOR_SYNC(qk, mask);
...@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -46,12 +133,17 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template <typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template <typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
// return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
// }
}; };
} // namespace vllm } // namespace vllm
\ No newline at end of file
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} \
}()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment