Commit b374a264 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove pa

parent c0707728
...@@ -249,11 +249,6 @@ set(VLLM_EXT_SRC ...@@ -249,11 +249,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu" "csrc/opt/activation_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu"
"csrc/attention/attention_kernels_opt_tc.cu"
"csrc/attention/attention_with_mask_kernels.cu"
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
# "csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu" "csrc/sampler.cu"
......
# <div align="center"><strong>vLLM</strong></div> # <div align="center"><strong>vLLM</strong></div>
## 简介
vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention高效管理kv内存,Continuous batching传入请求,支持很多Hugging Face模型,如LLaMA & LLaMA-2、Qwen、Chatglm2 & Chatglm3等。
## 暂不支持的官方功能
- **量化推理**:目前不支持marlin的权重量化、kv-cache fp8推理方案
- **模块支持**:目前不支持Sliding window attention
## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | 支持版本 | 是否优化 |
| :------: | :------: | :------: | :------: |:------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | v0.5.0,Llama 3.2>=v0.6.2 | Yes |
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3,Qwen3-Embedding,Qwen3-Reranker | Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
| DeepseekForCausalLM | Deepseek | Yes | No | - | v0.5.0 | Yes |
| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | No | - | v0.6.2 | Yes |
| DeepseekVLV2ForCausalLM | DeepSeek-VL2 | Yes | No | - | v0.7.2 | Yes |
| DeepseekV3ForCausalLM | DeepSeek-V3 | Yes | Yes | - | v0.7.2 | Yes |
| BaiChuanForCausalLM | Baichuan2,Baichuan | Yes | Yes | - | v0.5.0 | Yes |
| BloomForCausalLM | BLOOM | Yes | No | Yes | v0.5.0 | Yes |
| InternLMForCausalLM | InternLM | Yes | No | - | v0.5.0 | Yes |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - | v0.5.0 | Yes |
| FalconForCausalLM | falcon | Yes | No | Yes | v0.5.0 | Yes |
| TeleChat2ForCausalLM | TeleChat2 | Yes | No | - | v0.7.2 | Yes |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - | v0.5.0 | Yes |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - | v0.6.2 | Yes |
| MixtralForCausalLM | Mixtral-8x7B,Mixtral-8x7B-Instruct | Yes | No | - | v0.5.0 | Yes |
| Qwen2MoeForCausalLM | Qwen2-57B-A14B,Qwen2-57B-A14B-Instruct | Yes | No | - | v0.5.0 | No |
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | No | - | v0.6.2 | No |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes | v0.6.2 | No |
| Qwen2_5_VLForConditionalGeneration | Qwen.5-VL | Yes | No | Yes | v0.7.2 | No |
| Mistral3ForConditionalGeneration | Mistral3 | Yes | No | - | v0.8.5.post1 | No |
| Gemma3ForConditionalGeneration | Gemma 3 | Yes | - | - | v0.8.5.post1 | No |
| MiniCPMV | MiniCPM-V | Yes | No | - | v0.6.2 | No |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - | v0.6.2 | No |
| BertModel | bge-large-zh-v1.5 | Yes | No | - | v0.7.2 | No |
| XLMRobertaModel | bge-m3 | Yes | No | - | v0.7.2 | No |
| XLMRobertaForSequenceClassification | bge-reranker-v2-m3 | Yes | No | - | v0.7.2 | No |
## 安装 ## 安装
vLLM支持 vLLM支持
+ Python 3.9.
+ Python 3.10. + Python 3.10.
+ Python 3.11.
+ Python 3.12.
### 使用源码编译方式安装 ### 使用源码编译方式安装
#### 编译环境准备 #### 编译环境准备
提供2种环境准备方式:
1. 基于光源pytorch2.5.1基础镜像环境:根据pytorch2.5.1、python、dtk及系统下载对应的镜像版本。
2. 基于现有python环境:安装pytorch2.5.1,pytorch whl包下载目录:[https://cancon.hpccube.com:65024/4/main/pytorch](https://cancon.hpccube.com:65024/4/main/pytorch),根据python、dtk版本,下载对应pytorch2.5.1的whl包。安装命令如下: 基于光源vllm0.9.2基础镜像环境:
```shell ```shell
pip install torch* (下载的torch的whl包) docker pull image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2
pip install setuptools wheel
``` ```
镜像除编译环境外,已包含运行vllm需要的如下HCU依赖:
* DTK驱动:dtk25.04.1
* Pytorch: 2.5.1
* triton: 3.0.0
* lmslim: 0.3.1
* flash_attn: 2.6.1
* flash_mla: 1.0.0
* lightop: 0.5.0
#### 源码编译安装 #### 源码编译安装
1. 下载源码并进入目录
```shell ```shell
git clone http://developer.hpccube.com/codes/OpenDAS/vllm.git # 根据需要的分支进行切换 git clone -b v0.9.2 https://github.com/vllm-project/vllm.git
cd vllm
``` ```
安装依赖:
2. patch生成与执行(若单独打patch执行可忽略):
- 生成
```shell ```shell
pip install -r requirements/rocm.txt diff -Naur v0.9.2 patch-0.9.2+das.opt1.rc2.dtk2504 > patch_vllm.patch
``` ```
- 提供2种源码编译方式(进入vllm目录):
- 执行
```shell
patch -p1 < patch_vllm.patch
``` ```
1. 编译whl包并安装
python setup.py bdist_wheel
cd dist
pip install vllm*
2. 源码编译安装 3. 获取manylinux so并添加
python3 setup.py install (若调试,可使用python3 setup.py develop)
- 需要将该包安装目录下的_C.abi3.so和_moe_C.abi3.so拷贝至/opt/dtk/并添加软链接至vllm
```shell
cp /usr/local/lib/python3.10/dist-packages/vllm/*.so /opt/dtk/
ln -s /opt/dtk/*.so vllm/
``` ```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
#### 运行基础环境准备 4. 安装依赖:
1、使用上面基于光源pytorch2.5.1基础镜像环境 ```shell
pip install -r requirements/rocm.txt
```
2、根据pytorch2.5.1、python、dtk及系统下载对应的依赖包: 5. 编译及安装
- triton:[https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton/) - 编译whl包并安装
- flash_attn: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn) ```shell
- lmslim: [https://cancon.hpccube.com:65024/4/main/lmslim](https://cancon.hpccube.com:65024/4/main/lmslim) python setup.py bdist_wheel
cd dist
pip install vllm*
```
- 源码编译安装
```shell
pip install . --no-build-isolation
```
#### 注意事项 #### 注意事项
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
...@@ -103,5 +76,4 @@ python3 setup.py install (若调试,可使用python3 setup.py develop) ...@@ -103,5 +76,4 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
- -
## 参考资料 ## 参考资料
- [README_ORIGIN](README_ORIGIN.md)
- [https://github.com/vllm-project/vllm](https://github.com/vllm-project/vllm) - [https://github.com/vllm-project/vllm](https://github.com/vllm-project/vllm)
\ No newline at end of file
...@@ -119,42 +119,6 @@ def main( ...@@ -119,42 +119,6 @@ def main(
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
...@@ -173,44 +137,6 @@ def main( ...@@ -173,44 +137,6 @@ def main(
) )
elif version == "v2": elif version == "v2":
if not args.custom_paged_attn: if not args.custom_paged_attn:
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
...@@ -251,24 +177,6 @@ def main( ...@@ -251,24 +177,6 @@ def main(
k_scale, k_scale,
v_scale, v_scale,
) )
elif version == "v12":
from flash_attn import vllm_flash_attn_with_kvcache
vllm_flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=seq_lens,
block_table=block_tables,
softmax_scale=scale,
causal=True,
window_size=sliding_window,
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
return_softmax_lse=False,
k_scale=k_scale,
v_scale=v_scale,
kv_cache_dtype=kv_cache_dtype,
).squeeze(1)
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -298,7 +206,7 @@ if __name__ == "__main__": ...@@ -298,7 +206,7 @@ if __name__ == "__main__":
) )
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
parser.add_argument("--version", type=str, choices=["v1", "v2", "v12"], default="v12") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
...@@ -324,12 +232,6 @@ if __name__ == "__main__": ...@@ -324,12 +232,6 @@ if __name__ == "__main__":
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)") "ROCm (hcu) supports fp8 (=fp8_e4m3)")
parser.add_argument(
"--gc-paged-attn", action="store_true", help="Use gc paged attention"
)
parser.add_argument(
"--tc-paged-attn", action="store_true", help="Use tc paged attention"
)
parser.add_argument( parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention" "--custom-paged-attn", action="store_true", help="Use custom paged attention"
) )
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
// const float alibi_slope =
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant,
*k_scale_ptr);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale_ptr);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
}
}
}
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val;
}
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx];
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
*v_scale_ptr);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
}
}
// Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[reuse_kv_idx][i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1,
bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);}
});
});
});
});
});
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 256, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);}
});
});
});
});
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE 64
#include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
template <bool>
struct AccType {};
template <>
struct AccType<true> {
using type = uint16_t;
};
template <>
struct AccType<false> {
using type = float;
};
template<bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
static const std::string device_name=get_device_name();
static inline int get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
return atoi(value);
}
return 0;
}
static const int PA_USE_V1 = get_env_("PA_USE_V1");
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct half4x2{
half4_t data[2];
};
struct uint8x4x4{
uint8x4_t data[4];
};
template<typename scalar_t>
struct vec2data{
scalar_t data[2];
};
inline __device__ float uint82float(const uint8_t& input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
}
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
half4_t ret;
if constexpr(is_fp8){
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=uint82float(x[i])*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16(uint82float(x[i])*scale);
}
}
}
else{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=(x[i]-128.0f)*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16((x[i]-128.0f)*scale);
}
}
}
return ret;
}
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
dst[i]=src[i];
}
}
else{
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16 *>(&dst);
#pragma unroll
for(int i=0;i<4;i++){
out[i]=__float2bfloat16(src[i]);
}
}
}
template<bool is_half>
inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
if constexpr (is_half){
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
else{
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
}
template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES> // Zero means no partitioning.
__global__ void paged_attention_kernel_TC(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size]
scalar_t* __restrict__ out_tmp, // [num_seqs, num_heads, max_num_partitions,head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads,
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__)
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.x;
const int max_num_partitions = gridDim.x;
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
const int start_block_idx = partition_idx * num_blocks_per_partition;
const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
constexpr int x = 16 / sizeof(cache_t);
const int thread_idx = threadIdx.x;
const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16;
const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
int q_boundary=REUSE_KV_TIMES;
if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
const int kv_head_idx = head_idx / num_queries_per_kv;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
float alibi_slope[reuse_group]={0.f};
if(alibi_slopes != nullptr){
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<q_boundary) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx];
}
}
float qk_max[reuse_group];
for(int i=0;i<reuse_group;i++){
qk_max[i]=-FLT_MAX;
}
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
half4x2 q_vec;
q_vec.data[0]={0,0,0,0};
q_vec.data[1]={0,0,0,0};
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
for(int i=0;i<q_boundary;i++){
if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
if constexpr(is_half){
scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
#pragma unroll
for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
}
q_vecs[i][thread_idx]=temp;
}
}
__syncthreads();
extern __shared__ char shared_mem[];
ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll
for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
}
//tail
{
if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows];
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
}
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi;
}
const bool mask = (token_idx >= seq_len);
if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
}
}
// compute max
#pragma unroll
for (int mask = 8; mask >= 1; mask /= 2) {
#pragma unroll
for(int r=0;r<reuse_group;r++){
qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
}
}
#pragma unroll
for(int r=0;r<reuse_group;r++){
if(rowid==0&&r*4+rows<q_boundary){
s_max[r*4+rows][warp_idx] = qk_max[r];
}
}
__syncthreads();
__shared__ float max_out[REUSE_KV_TIMES];
__shared__ float expsum_out[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
const int head_idx_ = head_idx + reuse_kv_idx;
float qk_max_tmp = lane < NUM_WARPS ? s_max[reuse_kv_idx][lane] : -FLT_MAX;
float exp_sum = 0.f;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max_tmp = fmaxf(qk_max_tmp, __shfl_xor(qk_max_tmp, mask));
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
expsum_out[reuse_kv_idx]=exp_sum;
}
}
__syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<REUSE_KV_TIMES;k++)
{
accs[k][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){
for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
accs[resuseid][i]+=out_vec[resuseid];
}
}
}
}
}
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
}
if (warp_idx == 0) {
scalar_t* out_ptr;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_ptr=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
#if defined __gfx928__
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[GROUPS][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<GROUPS;k++)
{
accs[k][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
for(int g=0;g<reuse_group;g++){
accs[g*4+k][i]+=out_vec[g];
}
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr = out_ptr_base + reusekvid * out_offset;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[g*4+k][i]);
}
}
}
}
}
}
#else
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float4_t accs[4][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++)
{
accs[k][i] = {0.f,0.f,0.f,0.f};
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,accs[k][i]);
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
// Perform reduction across warps.
for(int m=0; m<4; m++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
#pragma unroll
for (int i = 0; i < reuse_group; i++) {
accs[m][k][i] += tmp[i];
}
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[k][i][g]);
}
}
}
}
}
}
#endif
if (USE_PARTITIONING&&thread_idx < q_boundary){
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
*(max_logits+offset)=max_out[thread_idx];
*(exp_sums+offset)=expsum_out[thread_idx];
}
#endif
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kernel_opt_tc(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions,int PARTITION_SIZE=512) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if(num_partitions==1)return;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const int lane = thread_idx % WARP_SIZE;
int offset = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
const float* max_logits_ptr = max_logits + offset;
const float* exp_sums_ptr = exp_sums + offset;
float max_logit = -FLT_MAX;
float global_max_logit = -FLT_MAX;
float global_exp_sum = 0.0f;
if constexpr(NUM_THREADS == 64&& HEAD_SIZE==128){
__shared__ float shared_exp_sums[64];
if(thread_idx<num_partitions){
max_logit = max_logits_ptr[thread_idx];
global_exp_sum = exp_sums_ptr[thread_idx];
global_max_logit = max_logit;
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_max_logit = fmaxf(global_max_logit, VLLM_SHFL_XOR_SYNC(global_max_logit, mask));
}
if(thread_idx<num_partitions){
global_exp_sum = global_exp_sum * __expf(max_logit - global_max_logit);
shared_exp_sums[thread_idx] = global_exp_sum;
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += VLLM_SHFL_XOR_SYNC(global_exp_sum, mask);
}
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr = tmp_out + offset * HEAD_SIZE;
using half2_t = vec2data<scalar_t>;
float2_t acc = {0.0f, 0.0f};
half2_t acc_half;
for (int j = 0; j < num_partitions; ++j) {
half2_t tout= *(half2_t*)(tmp_out_ptr + j * HEAD_SIZE + thread_idx*2);
float temp_sum=shared_exp_sums[j]*inv_global_exp_sum;
#pragma unroll
for(int i=0;i<2;i++){
acc[i] += to_float(tout.data[i])*temp_sum;
}
}
#pragma unroll
for(int i=0;i<2;i++){
from_float(acc_half.data[i],acc[i]);
}
*(half2_t*)(out_ptr+thread_idx*2)=acc_half;
}
else{
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
(vllm::paged_attention_kernel_TC< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step,PARTITION_SIZE); \
if (max_num_partitions<=64&&max_num_partitions>1){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
dim3(reduce_grid), dim3(64), 0, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS>), \
dim3(reduce_grid), dim3(NUM_THREADS), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks)
{
reusekv=1;
num_thread=256;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
else if(qheads==52||qheads==26){reusekv=16;}
else reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
return;
}
}
if(qheads==kvheads){
if(max_seq_len<=8192){
if(batchsize*qheads>=512){
max_num_partitions=1;
num_thread=64;
}
if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
}
return;
}
if(max_seq_len<800)max_num_partitions=1;
if(qheads>kvheads*4){
if(max_seq_len<=1000||
max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
max_seq_len<1900&&batchsize>=8&&qheads==28
)
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(batchsize>100&&max_seq_len>=2000){
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=1024;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
if(max_num_partitions==1){
if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize;
if(bytes<51200)reusekv=1;
else if(bytes<256000)reusekv=4;
else reusekv=8;
return;
}
if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
else reusekv=8;
return;
}
if(blocks<150)return;
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return;
}
if(batchsize>100&&max_seq_len>=2000){
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4;
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_blocks=key_cache.size(0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(num_seqs>100&&max_num_partitions>16)max_num_partitions=16;
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_USE_V1!=0)max_num_partitions=1;
if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.x = max_num_partitions;
grid.z = num_seqs;
dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
});
});
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
paged_attention_v2_opt_tc(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_with_mask_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale_ptr);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr && token_idx < seq_len) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f: qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_with_mask_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_with_mask_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_with_mask_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_with_mask_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_with_mask_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_with_mask_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V1(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V1(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 160:
LAUNCH_PAGED_ATTENTION_V1(160);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_with_mask_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V2(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V2(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 160:
LAUNCH_PAGED_ATTENTION_V2(160);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_with_mask_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
// const float alibi_slope =
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale_ptr);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr && token_idx < seq_len) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
// used for tree-style attention
/*if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}*/
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
}
}
}
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val;
}
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx];
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale_ptr);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
}
}
// Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[reuse_kv_idx][i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1,
bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_with_mask_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_with_mask_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_with_mask_kernel_opt<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_with_mask_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);}
});
});
});
});
});
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1_opt_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_with_mask_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 256, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);}
});
});
});
});
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_opt_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE 64
#include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
template <bool>
struct AccType {};
template <>
struct AccType<true> {
using type = uint16_t;
};
template <>
struct AccType<false> {
using type = float;
};
template<bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name();
static const std::string device_name=get_device_name();
static inline int get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
return atoi(value);
}
return 0;
}
static const int PA_USE_V1 = get_env_("PA_USE_V1");
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct half4x2{
half4_t data[2];
};
struct uint8x4x4{
uint8x4_t data[4];
};
template<typename scalar_t>
struct vec2data{
scalar_t data[2];
};
inline __device__ float uint82float(const uint8_t& input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
}
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
half4_t ret;
if constexpr(is_fp8){
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=uint82float(x[i])*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16(uint82float(x[i])*scale);
}
}
}
else{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=(x[i]-128.0f)*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16((x[i]-128.0f)*scale);
}
}
}
return ret;
}
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
dst[i]=src[i];
}
}
else{
__nv_bfloat16* out = reinterpret_cast<__nv_bfloat16 *>(&dst);
#pragma unroll
for(int i=0;i<4;i++){
out[i]=__float2bfloat16(src[i]);
}
}
}
template<bool is_half>
inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
if constexpr (is_half){
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
else{
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
}
template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
}
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES> // Zero means no partitioning.
__global__ void paged_attention_kernel_TC_with_mask(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size]
scalar_t* __restrict__ out_tmp, // [num_seqs, num_heads, max_num_partitions,head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads,
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__)
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.x;
const int max_num_partitions = gridDim.x;
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
const int start_block_idx = partition_idx * num_blocks_per_partition;
const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
constexpr int x = 16 / sizeof(cache_t);
const int thread_idx = threadIdx.x;
const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16;
const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
int q_boundary=REUSE_KV_TIMES;
if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
const int kv_head_idx = head_idx / num_queries_per_kv;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
float alibi_slope[reuse_group]={0.f};
if(alibi_slopes != nullptr){
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<q_boundary) alibi_slope[i]=alibi_slopes[head_idx+reuse_kv_idx];
}
}
float qk_max[reuse_group];
for(int i=0;i<reuse_group;i++){
qk_max[i]=-FLT_MAX;
}
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
half4x2 q_vec;
q_vec.data[0]={0,0,0,0};
q_vec.data[1]={0,0,0,0};
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
for(int i=0;i<q_boundary;i++){
if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
if constexpr(is_half){
scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
#pragma unroll
for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
}
q_vecs[i][thread_idx]=temp;
}
}
__syncthreads();
extern __shared__ char shared_mem[];
ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll
for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
}
//tail
{
if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows];
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
}
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi;
}
// used for tree-style attention
if (attn_masks != nullptr && token_idx < seq_len) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk_vec[i] = -FLT_MAX;
}
}
const bool mask = (token_idx >= seq_len);
if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
}
}
// compute max
#pragma unroll
for (int mask = 8; mask >= 1; mask /= 2) {
#pragma unroll
for(int r=0;r<reuse_group;r++){
qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
}
}
#pragma unroll
for(int r=0;r<reuse_group;r++){
if(rowid==0&&r*4+rows<q_boundary){
s_max[r*4+rows][warp_idx] = qk_max[r];
}
}
__syncthreads();
__shared__ float max_out[REUSE_KV_TIMES];
__shared__ float expsum_out[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
const int head_idx_ = head_idx + reuse_kv_idx;
float qk_max_tmp = lane < NUM_WARPS ? s_max[reuse_kv_idx][lane] : -FLT_MAX;
float exp_sum = 0.f;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max_tmp = fmaxf(qk_max_tmp, __shfl_xor(qk_max_tmp, mask));
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
expsum_out[reuse_kv_idx]=exp_sum;
}
}
__syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<REUSE_KV_TIMES;k++)
{
accs[k][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){
for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
accs[resuseid][i]+=out_vec[resuseid];
}
}
}
}
}
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
}
if (warp_idx == 0) {
scalar_t* out_ptr;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_ptr=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
#if defined __gfx928__
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[GROUPS][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<GROUPS;k++)
{
accs[k][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
for(int g=0;g<reuse_group;g++){
accs[g*4+k][i]+=out_vec[g];
}
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr = out_ptr_base + reusekvid * out_offset;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[g*4+k][i]);
}
}
}
}
}
}
#else
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float4_t accs[4][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++)
{
accs[k][i] = {0.f,0.f,0.f,0.f};
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,accs[k][i]);
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
// Perform reduction across warps.
for(int m=0; m<4; m++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
#pragma unroll
for (int i = 0; i < reuse_group; i++) {
accs[m][k][i] += tmp[i];
}
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[k][i][g]);
}
}
}
}
}
}
#endif
if (USE_PARTITIONING&&thread_idx < q_boundary){
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
*(max_logits+offset)=max_out[thread_idx];
*(exp_sums+offset)=expsum_out[thread_idx];
}
#endif
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kernel_opt_tc(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions,int PARTITION_SIZE=512) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if(num_partitions==1)return;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const int lane = thread_idx % WARP_SIZE;
int offset = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
const float* max_logits_ptr = max_logits + offset;
const float* exp_sums_ptr = exp_sums + offset;
float max_logit = -FLT_MAX;
float global_max_logit = -FLT_MAX;
float global_exp_sum = 0.0f;
if constexpr(NUM_THREADS == 64&& HEAD_SIZE==128){
__shared__ float shared_exp_sums[64];
if(thread_idx<num_partitions){
max_logit = max_logits_ptr[thread_idx];
global_exp_sum = exp_sums_ptr[thread_idx];
global_max_logit = max_logit;
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_max_logit = fmaxf(global_max_logit, VLLM_SHFL_XOR_SYNC(global_max_logit, mask));
}
if(thread_idx<num_partitions){
global_exp_sum = global_exp_sum * __expf(max_logit - global_max_logit);
shared_exp_sums[thread_idx] = global_exp_sum;
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += VLLM_SHFL_XOR_SYNC(global_exp_sum, mask);
}
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr = tmp_out + offset * HEAD_SIZE;
using half2_t = vec2data<scalar_t>;
float2_t acc = {0.0f, 0.0f};
half2_t acc_half;
for (int j = 0; j < num_partitions; ++j) {
half2_t tout= *(half2_t*)(tmp_out_ptr + j * HEAD_SIZE + thread_idx*2);
float temp_sum=shared_exp_sums[j]*inv_global_exp_sum;
#pragma unroll
for(int i=0;i<2;i++){
acc[i] += to_float(tout.data[i])*temp_sum;
}
}
#pragma unroll
for(int i=0;i<2;i++){
from_float(acc_half.data[i],acc[i]);
}
*(half2_t*)(out_ptr+thread_idx*2)=acc_half;
}
else{
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
(vllm::paged_attention_kernel_TC_with_mask< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step,attn_masks_ptr,attn_masks_stride,PARTITION_SIZE);\
if (max_num_partitions<=64&&max_num_partitions>1){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
dim3(reduce_grid), dim3(64), 0, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS>), \
dim3(reduce_grid), dim3(NUM_THREADS), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_blocks=key_cache.size(0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_USE_V1!=0)max_num_partitions=1;
if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.x = max_num_partitions;
grid.z = num_seqs;
dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
});
});
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt_tc_with_mask<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step,attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} \
}()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 32) { \
constexpr static int HEAD_SIZE = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM == 64) { \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 80) { \
constexpr static int HEAD_SIZE = 80; \
return __VA_ARGS__(); \
} else if (HEADDIM == 96) { \
constexpr static int HEAD_SIZE = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
}else if (NUM_THREAD == 128) { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 32) { \
constexpr static int HEAD_SIZE = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM == 64) { \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 80) { \
constexpr static int HEAD_SIZE = 80; \
return __VA_ARGS__(); \
} else if (HEADDIM == 96) { \
constexpr static int HEAD_SIZE = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (reusekv==2){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
...@@ -52,125 +52,6 @@ void paged_attention_v2( ...@@ -52,125 +52,6 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// paged_attention with attn_masks
void paged_attention_v1_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void merge_attn_states(torch::Tensor& output, void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_output,
......
...@@ -64,158 +64,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -64,158 +64,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// paged_attention with atth_masks
ops.def(
"paged_attention_v1_with_mask("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_with_mask", torch::kCUDA, &paged_attention_v1_with_mask);
// PagedAttention V2.
ops.def(
"paged_attention_v2_with_mask("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_with_mask", torch::kCUDA, &paged_attention_v2_with_mask);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_with_mask("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt_with_mask", torch::kCUDA, &paged_attention_v1_opt_with_mask);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_with_mask("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_with_mask", torch::kCUDA, &paged_attention_v2_opt_with_mask);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_tc_with_mask("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt_tc_with_mask", torch::kCUDA, &paged_attention_v1_opt_tc_with_mask);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_tc_with_mask("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);
// Merge attn states // Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case) // can be used to combine partial attention results (in the split-KV case)
......
...@@ -107,331 +107,6 @@ def paged_attention_v2( ...@@ -107,331 +107,6 @@ def paged_attention_v2(
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v1_with_mask(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step,attn_masks,
attn_masks_stride)
def paged_attention_v2_with_mask(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v1_opt_with_mask(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks,
attn_masks_stride)
def paged_attention_v2_opt_with_mask(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
# page attention ops (opt)
def paged_attention_v1_opt_tc(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v1_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt_tc(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v2_opt_tc(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt_tc_with_mask(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt_tc_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
def paged_attention_v2_opt_tc_with_mask(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt_tc_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
# def paged_attention_rocm( # def paged_attention_rocm(
# out: torch.Tensor, # out: torch.Tensor,
# exp_sum: torch.Tensor, # exp_sum: torch.Tensor,
......
...@@ -9,14 +9,12 @@ import torch ...@@ -9,14 +9,12 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import SUPPORT_TC
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
...@@ -154,117 +152,11 @@ class PagedAttention: ...@@ -154,117 +152,11 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
if use_tc and head_size==128: and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if attn_masks is None:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_tc_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
return output
use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
if attn_masks is None:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v1_opt_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
...@@ -287,30 +179,6 @@ class PagedAttention: ...@@ -287,30 +179,6 @@ class PagedAttention:
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step blocksparse_head_sliding_step
) )
else:
ops.paged_attention_v1_with_mask(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0 assert _PARTITION_SIZE % block_size == 0
...@@ -326,66 +194,6 @@ class PagedAttention: ...@@ -326,66 +194,6 @@ class PagedAttention:
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V2 SIZE:")
print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
if attn_masks is None:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
...@@ -411,33 +219,6 @@ class PagedAttention: ...@@ -411,33 +219,6 @@ class PagedAttention:
blocksparse_block_size, blocksparse_block_size,
blocksparse_head_sliding_step blocksparse_head_sliding_step
) )
else:
ops.paged_attention_v2_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
return output return output
@staticmethod @staticmethod
......
...@@ -147,8 +147,6 @@ if TYPE_CHECKING: ...@@ -147,8 +147,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_SPEC_DECODE_EAGER: bool = False VLLM_SPEC_DECODE_EAGER: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
...@@ -1017,16 +1015,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1017,16 +1015,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
("true", "1")),
# If set, vLLM will disable the draft model in cudagraph mode. # If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER": "VLLM_SPEC_DECODE_EAGER":
lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))), lambda: bool(int(os.getenv("VLLM_SPEC_DECODE_EAGER", "0"))),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment