Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1be9a629
Commit
1be9a629
authored
Jul 22, 2024
by
zhangshao
Browse files
pa优化,编译选项优化
parent
d4c0015a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
623 additions
and
313 deletions
+623
-313
CMakeLists.txt
CMakeLists.txt
+8
-5
cmake/utils.cmake
cmake/utils.cmake
+5
-1
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+418
-301
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+98
-6
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+94
-0
No files found.
CMakeLists.txt
View file @
1be9a629
...
@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX)
...
@@ -4,11 +4,13 @@ project(vllm_extensions LANGUAGES CXX)
option
(
VLLM_TARGET_DEVICE
"Target device backend for vLLM"
"cuda"
)
option
(
VLLM_TARGET_DEVICE
"Target device backend for vLLM"
"cuda"
)
set
(
CMAKE_BUILD_TYPE
"Release"
)
message
(
STATUS
"Build type:
${
CMAKE_BUILD_TYPE
}
"
)
message
(
STATUS
"Build type:
${
CMAKE_BUILD_TYPE
}
"
)
message
(
STATUS
"Target device:
${
VLLM_TARGET_DEVICE
}
"
)
message
(
STATUS
"Target device:
${
VLLM_TARGET_DEVICE
}
"
)
include
(
${
CMAKE_CURRENT_LIST_DIR
}
/cmake/utils.cmake
)
include
(
${
CMAKE_CURRENT_LIST_DIR
}
/cmake/utils.cmake
)
add_compile_options
(
-w
)
#
#
# Supported python versions. These versions will be searched in order, the
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
# first match will be selected. These should be kept in sync with setup.py.
...
@@ -120,10 +122,11 @@ endif()
...
@@ -120,10 +122,11 @@ endif()
# the supported versions for the current language.
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
#
#
override_gpu_arches
(
VLLM_GPU_ARCHES
#override_gpu_arches(VLLM_GPU_ARCHES
${
VLLM_GPU_LANG
}
# ${VLLM_GPU_LANG}
"
${${
VLLM_GPU_LANG
}
_SUPPORTED_ARCHS
}
"
)
# "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
set
(
VLLM_GPU_ARCHES
"gfx928"
)
message
(
STATUS
"
${
VLLM_GPU_ARCHES
}
"
)
#
#
# Query torch for additional GPU compilation flags for the given
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
# `VLLM_GPU_LANG`.
...
...
cmake/utils.cmake
View file @
1be9a629
...
@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -117,6 +117,10 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch nvcc compiler flags"
)
"Failed to determine torch nvcc compiler flags"
)
list
(
REMOVE_ITEM GPU_FLAGS
"-DUSE_ROCM=1"
)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
# "-DENABLE_FP8"
# "-DENABLE_FP8"
...
@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -124,7 +128,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
"-fno-gpu-rdc"
"--gpu-max-threads-per-block=1024"
)
"--gpu-max-threads-per-block=1024"
)
message
(
STATUS
"
${
GPU_FLAGS
}
"
)
endif
()
endif
()
set
(
${
OUT_GPU_FLAGS
}
${
GPU_FLAGS
}
PARENT_SCOPE
)
set
(
${
OUT_GPU_FLAGS
}
${
GPU_FLAGS
}
PARENT_SCOPE
)
endfunction
()
endfunction
()
...
...
csrc/attention/attention_kernels.cu
View file @
1be9a629
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
...
@@ -39,6 +20,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
...
@@ -39,6 +20,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize
#define WARP_SIZE warpSize
#endif
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...
@@ -86,7 +69,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {
...
@@ -86,7 +69,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
,
std
::
enable_if_t
<!
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
...
@@ -98,7 +83,39 @@ __device__ void paged_attention_kernel(
...
@@ -98,7 +83,39 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
=
1
,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
,
std
::
enable_if_t
<
std
::
is_same
<
scalar_t
,
uint16_t
>
::
value
,
int
>
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
...
@@ -108,9 +125,9 @@ __device__ void paged_attention_kernel(
...
@@ -108,9 +125,9 @@ __device__ void paged_attention_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
max_num_partitions
=
gridDim
.
z
;
const
int
max_num_partitions
=
gridDim
.
y
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
seq_len
=
seq_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
...
@@ -121,7 +138,7 @@ __device__ void paged_attention_kernel(
...
@@ -121,7 +138,7 @@ __device__ void paged_attention_kernel(
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
...
@@ -144,22 +161,38 @@ __device__ void paged_attention_kernel(
...
@@ -144,22 +161,38 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
warp_idx
=
thread_idx
/
WARP_SIZE
;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
head_idx
=
blockIdx
.
x
;
int
warp_id_vec
=
threadIdx
.
x
/
WARP_SIZE
;
//warp id in a block
const
int
num_heads
=
gridDim
.
x
;
int
warp_idx
=
0
;
asm
volatile
(
"v_readfirstlane_b32 %0,%1"
:
"=s"
(
warp_idx
)
:
"v"
(
warp_id_vec
)
:
);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
// const float alibi_slope =
const
float
alibi_slope
=
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
// A vector type to store a part of a key or a query.
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
// (4 * sizeof(half)) == 2.
constexpr
int
VEC_SIZE
=
MAX
(
16
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
constexpr
int
VEC_SIZE
=
MAX
(
32
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
...
@@ -176,61 +209,89 @@ __device__ void paged_attention_kernel(
...
@@ -176,61 +209,89 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
// const scalar_t* q_ptr = q + seq_idx * q_stride;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
const
scalar_t
*
q_ptr_offset
=
q
+
seq_idx
*
q_stride
;
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
__shared__
Q_vec
q_vecs
[
REUSE_KV_TIMES
*
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
i
+=
NUM_THREAD_GROUPS
)
{
// #pragma unroll
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
q_vecs
[
thread_group_offset
][
i
]
=
// i += NUM_THREAD_GROUPS) {
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
}
// q_vecs[thread_group_offset][i] =
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// memory wall right before we use q_vecs
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning.
// Memory planning.
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Workspace for reduction.
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
__shared__
float
red_smem
[
REUSE_KV_TIMES
][
2
*
NUM_WARPS
];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
// Each thread group fetches x elements from the key at a time.
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
=
-
FLT_MAX
;
float
qk_max
[
REUSE_KV_TIMES
];
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
qk_max
[
reuse_kv_idx
]
=
-
FLT_MAX
;
}
const
int
num_blocks_per_kv
=
((
num_queries_per_kv
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
);
const
int
head_idx_soffset
=
(
blockIdx
.
x
/
num_blocks_per_kv
)
*
num_queries_per_kv
+
(
blockIdx
.
x
%
num_blocks_per_kv
)
*
REUSE_KV_TIMES
;
const
int
kv_head_idx
=
head_idx_soffset
/
num_queries_per_kv
;
const
int
q_boundary
=
(
kv_head_idx
+
1
)
*
num_queries_per_kv
;
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const
scalar_t
*
q_ptr
=
q_ptr_offset
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks.
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// For blocksparse attention: skip computation on blocks that are not
// attended
// attended
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
bool
is_remote
=
const
bool
is_remote
=
...
@@ -254,8 +315,8 @@ __device__ void paged_attention_kernel(
...
@@ -254,8 +315,8 @@ __device__ void paged_attention_kernel(
continue
;
continue
;
}
}
}
}
const
int64_t
physical_block_number
=
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// Each thread in a thread group has a different part of the key.
...
@@ -263,99 +324,104 @@ __device__ void paged_attention_kernel(
...
@@ -263,99 +324,104 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
if
(
reuse_kv_idx
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
}
else
{
// Vector conversion from Quant_vec to K_vec.
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
kv_scale
);
k_vec_quant
,
kv_scale
);
}
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Compute dot product.
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
reuse_kv_idx
*
THREAD_GROUP_SIZE
+
thread_group_offset
],
k_vecs
);
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
__builtin_amdgcn_sched_barrier
(
0
);
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
logits
[
(
reuse_kv_idx
*
partition_size
)
+
(
token_idx
-
start_token_idx
)
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
qk_max
[
reuse_kv_idx
]
=
mask
?
qk_max
[
reuse_kv_idx
]
:
fmaxf
(
qk_max
[
reuse_kv_idx
]
,
qk
);
}
}
}
}
}
}
}
}
// Get the sum of the exp values.
float
exp_sum
[
REUSE_KV_TIMES
]
=
{
0.
f
};
// Perform reduction across the threads in the same warp to get the
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
}
#pragma unroll
if
(
lane
==
0
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
red_smem
[
warp_idx
]
=
qk_max
;
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
}
}
__syncthreads
();
if
(
lane
==
0
)
{
red_smem
[
reuse_kv_idx
][
warp_idx
]
=
qk_max
[
reuse_kv_idx
];
// TODO(woosuk): Refactor this part.
}
// Get the max qk value for the sequence.
__syncthreads
();
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
// TODO(woosuk): Refactor this part.
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
// Get the max qk value for the sequence.
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
qk_max
[
reuse_kv_idx
]
=
lane
<
NUM_WARPS
?
red_smem
[
reuse_kv_idx
][
lane
]
:
-
FLT_MAX
;
}
#pragma unroll
// Broadcast the max qk value to all threads.
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
VLLM_SHFL_SYNC
(
qk_max
,
0
);
qk_max
[
reuse_kv_idx
]
=
fmaxf
(
qk_max
[
reuse_kv_idx
],
VLLM_SHFL_XOR_SYNC
(
qk_max
[
reuse_kv_idx
],
mask
));
}
// Broadcast the max qk value to all threads.
qk_max
[
reuse_kv_idx
]
=
VLLM_SHFL_SYNC
(
qk_max
[
reuse_kv_idx
],
0
);
// Get the sum of the exp values.
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
exp_sum
=
0.
f
;
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max
[
reuse_kv_idx
]);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
float
val
=
__expf
(
logits
[
i
]
-
qk_max
);
exp_sum
[
reuse_kv_idx
]
+=
val
;
logits
[
i
]
=
val
;
}
exp_sum
+=
val
;
exp_sum
[
reuse_kv_idx
]
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
reuse_kv_idx
][
NUM_WARPS
],
exp_sum
[
reuse_kv_idx
]);
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
[
reuse_kv_idx
]
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[
i
]
*=
inv_sum
;
logits
[
(
reuse_kv_idx
*
partition_size
)
+
i
]
*=
inv_sum
;
}
}
__syncthreads
();
__syncthreads
();
// If partitioning is enabled, store the max logit and exp_sum.
// If partitioning is enabled, store the max logit and exp_sum.
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
;
*
max_logits_ptr
=
qk_max
[
reuse_kv_idx
];
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
*
exp_sums_ptr
=
exp_sum
[
reuse_kv_idx
];
}
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
...
@@ -369,44 +435,74 @@ __device__ void paged_attention_kernel(
...
@@ -369,44 +435,74 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
float
accs
[
REUSE_KV_TIMES
][
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
}
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
reuse_kv_idx
][
i
]
=
0.
f
;
}
}
scalar_t
zero_value
;
scalar_t
zero_value
;
zero
(
zero_value
);
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
}
}
const
int64_t
physical_block_number
=
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
L_vec
logits_vec
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
token_idx
-
start_token_idx
));
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
V_vec
v_vec
;
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
}
}
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
(
reuse_kv_idx
*
partition_size
)
+
token_idx
-
start_token_idx
));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if
(
reuse_kv_idx
==
0
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
V_vec
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
...
@@ -428,20 +524,41 @@ __device__ void paged_attention_kernel(
...
@@ -428,20 +524,41 @@ __device__ void paged_attention_kernel(
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs
[
reuse_kv_idx
][
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
}
}
}
}
}
// Perform reduction within each warp.
// Perform reduction within each warp.
#pragma unroll
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
REUSE_KV_TIMES
;
reuse_kv_idx
++
)
{
int
head_idx
=
head_idx_soffset
+
reuse_kv_idx
;
if
(
!
odd_nheads
||
head_idx
<
q_boundary
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
float
acc
=
accs
[
reuse_kv_idx
][
i
];
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
}
accs
[
i
]
=
acc
;
accs
[
reuse_kv_idx
][
i
]
=
acc
;
}
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// NOTE(woosuk): A barrier is required because the shared memory space for
...
@@ -455,12 +572,12 @@ __device__ void paged_attention_kernel(
...
@@ -455,12 +572,12 @@ __device__ void paged_attention_kernel(
int
mid
=
i
/
2
;
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
float
*
dst
=
&
out_smem
[
(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
i
];
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
}
}
}
}
}
}
...
@@ -468,12 +585,12 @@ __device__ void paged_attention_kernel(
...
@@ -468,12 +585,12 @@ __device__ void paged_attention_kernel(
// Lower warps update the output.
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
const
float
*
src
=
&
out_smem
[
(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
i
]
+=
src
[
row_idx
];
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
}
}
}
}
}
}
...
@@ -489,23 +606,29 @@ __device__ void paged_attention_kernel(
...
@@ -489,23 +606,29 @@ __device__ void paged_attention_kernel(
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
i
]);
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
reuse_kv_idx
][
i
]);
}
}
}
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
int
REUSE_KV_TIMES
=
1
,
__global__
void
paged_attention_v1_kernel
(
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
...
@@ -516,22 +639,24 @@ __global__ void paged_attention_v1_kernel(
...
@@ -516,22 +639,24 @@ __global__ void paged_attention_v1_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
int
REUSE_KV_TIMES
,
__global__
void
paged_attention_v2_kernel
(
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -542,7 +667,8 @@ __global__ void paged_attention_v2_kernel(
...
@@ -542,7 +667,8 @@ __global__ void paged_attention_v2_kernel(
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_kv_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
...
@@ -552,19 +678,19 @@ __global__ void paged_attention_v2_kernel(
...
@@ -552,19 +678,19 @@ __global__ void paged_attention_v2_kernel(
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
float
kv_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
kv_block_stride
,
kv_head_stride
,
kv_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs).
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
...
@@ -674,22 +800,32 @@ __global__ void paged_attention_v2_reduce_kernel(
...
@@ -674,22 +800,32 @@ __global__ void paged_attention_v2_reduce_kernel(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
KV_DTYPE,
REUSE_KV_TIMES,
IS_BLOCK_SPARSE
, odd_nheads
>), \
shared_mem_size); \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL((
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE
>
\
NUM_THREADS, KV_DTYPE,
REUSE_KV_TIMES,
IS_BLOCK_SPARSE
, odd_nheads>)
\
<<<
grid, block, shared_mem_size, stream
>>>(
\
, dim3(
grid
)
,
dim3(
block
)
, shared_mem_size, stream
,
\
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,
num_heads,
num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \
kv_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
int
NUM_THREADS
=
128
>
void
paged_attention_v1_launcher
(
void
paged_attention_v1_launcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
...
@@ -705,7 +841,10 @@ void paged_attention_v1_launcher(
...
@@ -705,7 +841,10 @@ void paged_attention_v1_launcher(
int
q_stride
=
query
.
stride
(
0
);
int
q_stride
=
query
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
num_threads
=
128
;
if
(
num_heads
!=
num_kv_heads
){
num_threads
=
256
;
}
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
assert
(
head_size
%
thread_group_size
==
0
);
...
@@ -722,48 +861,31 @@ void paged_attention_v1_launcher(
...
@@ -722,48 +861,31 @@ void paged_attention_v1_launcher(
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
int
*
seq_lens_ptr
=
seq_lens
.
data_ptr
<
int
>
();
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
padded_max_seq_len
=
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
int
padded_max_seq_len
=
REUSEKV_SWITCH_V1
(
num_heads
*
num_seqs
,
[
&
]
{
DIVIDE_ROUND_UP
(
max_seq_len
,
BLOCK_SIZE
)
*
BLOCK_SIZE
;
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
int
logits_size
=
padded_max_seq_len
*
sizeof
(
float
);
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
NUM_THREADS_SWITCH
(
num_threads
,
[
&
]
{
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
// Keep that in sync with the logic here!
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
shared_mem_size
=
std
::
max
(
logits_size
,
outputs_size
);
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
sizeof
(
float
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
(
num_heads
,
num_seqs
,
1
);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
dim3
block
(
NUM_THREADS
);
// Keep that in sync with the logic here!
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
switch
(
head_size
)
{
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// head sizes that we use in the model. However, we can easily extend this
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
// to support any head size which is a multiple of 16.
dim3
block
(
NUM_THREADS
);
case
64
:
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
LAUNCH_PAGED_ATTENTION_V1
(
64
);
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
break
;
LAUNCH_PAGED_ATTENTION_V1
(
HEAD_SIZE
);
case
80
:
});
LAUNCH_PAGED_ATTENTION_V1
(
80
);
});
break
;
});
case
96
:
});
LAUNCH_PAGED_ATTENTION_V1
(
96
);
});
break
;
case
112
:
LAUNCH_PAGED_ATTENTION_V1
(
112
);
break
;
case
128
:
LAUNCH_PAGED_ATTENTION_V1
(
128
);
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V1
(
192
);
break
;
case
256
:
LAUNCH_PAGED_ATTENTION_V1
(
256
);
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported head size: "
,
head_size
);
break
;
}
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...
@@ -788,20 +910,25 @@ void paged_attention_v1_launcher(
...
@@ -788,20 +910,25 @@ void paged_attention_v1_launcher(
// 1, 2, 4, 64, 128, 256.
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
break; \
}
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void
paged_attention_v1
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
...
@@ -826,19 +953,19 @@ void paged_attention_v1(
...
@@ -826,19 +953,19 @@ void paged_attention_v1(
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL((
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE>
\
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>)
\
<<<
grid, block, shared_mem_size, stream
>>>(
\
, dim3(
grid
)
,
dim3(
block
)
, shared_mem_size, stream
,
\
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
value_cache_ptr,
num_heads,
num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL((
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
PARTITION_SIZE>
)
\
<<<
reduce_grid, block, reduce_shared_mem_size, stream
>>>(
\
, dim3(
reduce_grid
)
,
dim3(
block
)
, reduce_shared_mem_size, stream
,
\
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
max_num_partitions);
...
@@ -883,48 +1010,33 @@ void paged_attention_v2_launcher(
...
@@ -883,48 +1010,33 @@ void paged_attention_v2_launcher(
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
int
logits_size
=
PARTITION_SIZE
*
sizeof
(
float
);
REUSEKV_SWITCH
(
num_heads
*
max_num_partitions
*
num_seqs
,
[
&
]
{
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
BOOL_SWITCH
((
num_heads
/
num_kv_heads
%
REUSE_KV_TIMES
!=
0
),
odd_nheads
,
[
&
]
{
HEADSIZE_SWITCH
(
head_size
,
[
&
]
{
// For paged attention v2 kernel.
OPT_SWITCH
(
num_heads
==
num_kv_heads
,
[
&
]
{
dim3
grid
(
num_heads
,
num_seqs
,
max_num_partitions
);
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
float
);
int
shared_mem_size
=
std
::
max
(
logits_size
,
outputs_size
);
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
// For paged attention v2 reduce kernel.
dim3
reduce_grid
(
num_heads
,
num_seqs
);
// For paged attention v2 kernel.
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3
block
(
NUM_THREADS
);
dim3
grid
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
grid
.
y
=
max_num_partitions
;
switch
(
head_size
)
{
grid
.
z
=
num_seqs
;
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
// head sizes that we use in the model. However, we can easily extend this
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
// to support any head size which is a multiple of 16.
// For paged attention v2 reduce kernel.
case
64
:
dim3
reduce_grid
(
num_heads
,
num_seqs
);
LAUNCH_PAGED_ATTENTION_V2
(
64
);
int
reduce_shared_mem_size
=
2
*
max_num_partitions
*
sizeof
(
float
);
break
;
dim3
block
(
NUM_THREADS
);
case
80
:
const
at
::
hip
::
OptionalHIPGuardMasqueradingAsCUDA
device_guard
(
device_of
(
query
));
LAUNCH_PAGED_ATTENTION_V2
(
80
);
const
hipStream_t
stream
=
at
::
hip
::
getCurrentHIPStreamMasqueradingAsCUDA
();
break
;
LAUNCH_PAGED_ATTENTION_V2
(
HEAD_SIZE
);
case
96
:
});
LAUNCH_PAGED_ATTENTION_V2
(
96
);
});
break
;
});
case
112
:
});
LAUNCH_PAGED_ATTENTION_V2
(
112
);
break
;
case
128
:
LAUNCH_PAGED_ATTENTION_V2
(
128
);
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V2
(
192
);
break
;
case
256
:
LAUNCH_PAGED_ATTENTION_V2
(
256
);
break
;
default:
TORCH_CHECK
(
false
,
"Unsupported head size: "
,
head_size
);
break
;
}
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...
@@ -949,20 +1061,25 @@ void paged_attention_v2_launcher(
...
@@ -949,20 +1061,25 @@ void paged_attention_v2_launcher(
// 1, 2, 4, 64, 128, 256.
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
break; \
}
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void
paged_attention_v2
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
...
@@ -992,4 +1109,4 @@ void paged_attention_v2(
...
@@ -992,4 +1109,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef WARP_SIZE
#undef MAX
#undef MAX
#undef MIN
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/attention_utils.cuh
View file @
1be9a629
...
@@ -26,19 +26,106 @@
...
@@ -26,19 +26,106 @@
namespace
vllm
{
namespace
vllm
{
// Q*K^T operation.
// Thin wrapper around the AMD GPU `v_dot2_f32_f16` instruction.
//
// `a` is bound both as the destination register and (through the matching
// constraint "0") as the last source operand, so the instruction updates `a`
// in place using `b` and `c`.
//
// NOTE(review): per the AMD ISA this is a 2-way fp16 dot product accumulated
// into fp32, i.e. `b` and `c` are presumably two packed halves each -- confirm
// against the ISA manual for the targeted architecture.
inline __device__ void v_dot2_f32_f16(float& a, const uint32_t& b,
                                      const uint32_t& c) {
  asm volatile("v_dot2_f32_f16 %0, %1, %2, %0;"
               : "=v"(a)
               : "v"(b), "v"(c), "0"(a));
}
// Thin wrapper around the AMD GPU `v_pk_fma_f16` instruction.
//
// `b` and `c` are the first two source operands; the previous value of `a` is
// passed as the third source operand and the result is written back to `a`.
//
// NOTE(review): the instruction name suggests a packed (2 x fp16) fused
// multiply-add, a = b * c + a per half lane -- confirm against the ISA manual.
inline __device__ void v_pk_fma_f16(uint32_t& a, const uint32_t& b,
                                    const uint32_t& c) {
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;"
               : "=v"(a)
               : "v"(b), "v"(c), "v"(a));
}
// Issues a `ds_read_b128` with `offset` as the address operand and `a` as the
// destination.
//
// The asm statement contains no `s_waitcnt`, so nothing in this function
// waits for the read to complete.
// NOTE(review): callers presumably have to wait (e.g. via lgkmcnt0()) before
// consuming `a` -- confirm at the call sites.
inline __device__ void ds_read_b128(uint4& a, uint32_t offset) {
  asm volatile("ds_read_b128 %0 %1;"
               : "=v"(a)
               : "v"(offset));
}
// Issues a `ds_read_b128` (address operand `offset`, destination `a`)
// immediately followed by `s_waitcnt lgkmcnt(1)` in the same asm statement.
//
// NOTE(review): lgkmcnt(1) waits until at most one LDS/GDS/constant/message
// operation is still outstanding, so this presumably guarantees completion of
// an *earlier* read rather than of the one issued here -- confirm against the
// intended call pattern.
inline __device__ void ds_read_b128_sync(uint4& a, uint32_t offset) {
  asm volatile("ds_read_b128 %0 %1\ns_waitcnt lgkmcnt(1);"
               : "=v"(a)
               : "v"(offset));
}
// Emits `s_waitcnt lgkmcnt(0)`: stalls until the LGKM counter reaches zero,
// i.e. no counted operation is still in flight.
inline __device__ void lgkmcnt0() {
  asm volatile("s_waitcnt lgkmcnt(0);");
}
// Converts a generic pointer into an integer by first casting it to a pointer
// in address space 3 and then to size_t.
//
// NOTE(review): address space 3 is presumably the LDS / shared-memory address
// space on the AMDGPU target, making the result usable as a DS instruction
// offset -- confirm against the LLVM AMDGPU documentation.
__device__ inline size_t __nv_cvta_generic_to_shared_impl(const void* __ptr) {
  // A C-style cast is kept on purpose: it performs the address-space
  // conversion (and drops const) exactly as the original expression did.
  void __attribute__((address_space(3)))* lds_ptr =
      (void __attribute__((address_space(3)))*)__ptr;
  return (size_t)lds_ptr;
}
// uint2 overload: applies the 32-bit v_dot2_f32_f16 wrapper to the `.x`
// components and then to the `.y` components, accumulating both into `a`.
inline __device__ void v_dot2_f32_f16(float& a, const uint2& b,
                                      const uint2& c) {
  v_dot2_f32_f16(a, b.x, c.x);  // first 32-bit word
  v_dot2_f32_f16(a, b.y, c.y);  // second 32-bit word
}
// uint4 overload: applies the 32-bit v_dot2_f32_f16 wrapper to each of the
// four components in x, y, z, w order, accumulating every result into `a`.
inline __device__ void v_dot2_f32_f16(float& a, const uint4& b,
                                      const uint4& c) {
  v_dot2_f32_f16(a, b.x, c.x);  // word 0
  v_dot2_f32_f16(a, b.y, c.y);  // word 1
  v_dot2_f32_f16(a, b.z, c.z);  // word 2
  v_dot2_f32_f16(a, b.w, c.w);  // word 3
}
// Reinterprets the 32-bit word `a` as two `half` values, adds them and
// returns the sum converted to float.
//
// The addition is performed on the two `half` operands before the conversion
// to float, exactly as written in the return expression.
inline __device__ float add_half2(uint32_t a) {
  // Union used to view the same 32 bits as a pair of halves.
  union {
    uint32_t word;
    half lanes[2];
  } view;
  view.word = a;
  return static_cast<float>(view.lanes[0] + view.lanes[1]);
}
// Combines the four 32-bit words of `b` and `c` through packed operations and
// adds the result to `a`.
//
// Steps, as performed by the code:
//   1. `mul` on the `.x` words seeds a 32-bit running value.
//   2. `v_pk_fma_f16` folds in the `.y`, `.z` and `.w` words, in that order.
//   3. `add_half2` collapses the running value to a float, which is added
//      to `a`.
//
// NOTE(review): the running value is presumably a pair of fp16 partial sums,
// so the accumulation before step 3 happens at half precision -- confirm this
// is acceptable for the callers.
inline __device__ void v_pk_fma_f16x8(float& a, const uint4& b,
                                      const uint4& c) {
  uint32_t packed_acc = mul<uint32_t, uint32_t, uint32_t>(b.x, c.x);
  v_pk_fma_f16(packed_acc, b.y, c.y);
  v_pk_fma_f16(packed_acc, b.z, c.z);
  v_pk_fma_f16(packed_acc, b.w, c.w);
  a += add_half2(packed_acc);
}
// Q*K^T operation. fp16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<std::is_same<scalar_t, uint16_t>::value, int> = 0>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
template
<
int
THREAD_GROUP_SIZE
,
typename
Vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
inline
__device__
float
qk_dot_
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
float
qk
=
0
;
// uint32_t offset = __nv_cvta_generic_to_shared_impl(q);
// const uint4 *k_ptr= reinterpret_cast<const uint4 *>(k);
// // Compute the parallel products for Q*K^T (treat vector lanes separately).
// constexpr int loop=N*sizeof(Vec)/16/2;
// uint4 qt[2];
// #pragma unroll
// for (int ii = 0; ii < loop; ++ii) {
// ds_read_b128(qt[0],offset+16*ii*2);
// ds_read_b128_sync(qt[1],offset+16*(ii*2+1));
// v_dot2_f32_f16(qk,qt[0],k_ptr[ii*2]);
// // v_pk_fma_f16x8(qk,qt[0],k_ptr[ii*2]);
// lgkmcnt0();
// v_dot2_f32_f16(qk,qt[1],k_ptr[ii*2+1]);
// // v_pk_fma_f16x8(qk,qt[1],k_ptr[ii*2+1]);
// }
#pragma unroll
for
(
int
ii
=
0
;
ii
<
N
;
++
ii
)
{
v_dot2_f32_f16
(
qk
,
q
[
ii
],
k
[
ii
]);
}
// Finalize the reduction across lanes.
#pragma unroll
for
(
int
mask
=
THREAD_GROUP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
VLLM_SHFL_XOR_SYNC
(
qk
,
mask
);
}
return
qk
;
}
// Q*K^T operation. //bf16
// template <int THREAD_GROUP_SIZE, typename Vec, int N, typename scalar_t, std::enable_if_t<!std::is_same<scalar_t, uint16_t>::value, int> = 0>
// Computes this thread's share of a Q*K^T dot product using the vector
// mul/fma helpers, then sums it over the thread group.
//   THREAD_GROUP_SIZE : number of lanes whose partial sums are combined.
//   q, k              : N packed vectors each; element ii of q is paired with
//                       element ii of k.
// Returns the accumulated value after the cross-lane exchange loop.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_vpack_(const Vec (&q)[N], const Vec (&k)[N]) {
  // Accumulator vector type associated with Vec through FloatVec.
  using A_vec = typename FloatVec<Vec>::Type;

  // Compute the parallel products for Q*K^T (treat vector lanes separately).
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Collapse the accumulator vector to one scalar.
  float qk = sum(qk_vec);

  // Finalize the reduction across lanes.
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}
template
<
typename
T
,
int
THREAD_GROUP_SIZE
>
template
<
typename
T
,
int
THREAD_GROUP_SIZE
>
struct
Qk_dot
{
struct
Qk_dot
{
template
<
typename
Vec
,
int
N
>
template
<
typename
Vec
,
int
N
>
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
static
inline
__device__
float
dot
(
const
Vec
(
&
q
)[
N
],
const
Vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
return
qk_dot_
<
THREAD_GROUP_SIZE
>
(
q
,
k
);
}
}
// template <typename Vec, int N>
// static inline __device__ float qk_dot_vpack(const Vec (&q)[N], const Vec (&k)[N]) {
// return qk_dot_vpack_<THREAD_GROUP_SIZE>(q, k);
// }
};
};
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/attention/static_switch.h
0 → 100644
View file @
1be9a629
// Turns a runtime boolean into a compile-time constant.
// Expands to an immediately-invoked lambda that declares `CONST_NAME` as a
// static constexpr bool (true when COND holds, false otherwise), invokes the
// callable given as the trailing argument(s) and returns its result.
#define BOOL_SWITCH(COND, CONST_NAME, ...)        \
  [&]() {                                         \
    if (COND) {                                   \
      static constexpr bool CONST_NAME = true;    \
      return __VA_ARGS__();                       \
    }                                             \
    static constexpr bool CONST_NAME = false;     \
    return __VA_ARGS__();                         \
  }()
// Selects between two compile-time variants from a runtime condition.
// Expands to an immediately-invoked lambda that declares a static constexpr
// int named `opt` (1 when COND holds, 2 otherwise), invokes the callable given
// as the trailing argument(s) and returns its result.
#define OPT_SWITCH(COND, ...)             \
  [&]() {                                 \
    if (COND) {                           \
      static constexpr int opt = 1;       \
      return __VA_ARGS__();               \
    }                                     \
    static constexpr int opt = 2;         \
    return __VA_ARGS__();                 \
  }()
// Maps a runtime thread count onto a compile-time constant.
// Expands to an immediately-invoked lambda that declares `NUM_THREADS` as a
// static constexpr int equal to 256 when NUM_THREAD equals 256 and equal to
// 128 for every other value, then invokes the callable given as the trailing
// argument(s) and returns its result.
// The argument is parenthesized so that a compound expression such as
// `c ? 128 : 256` is compared as a whole rather than being split by the
// higher-precedence `==`.
#define NUM_THREADS_SWITCH(NUM_THREAD, ...)     \
  [&] {                                         \
    if ((NUM_THREAD) == 256) {                  \
      constexpr static int NUM_THREADS = 256;   \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static int NUM_THREADS = 128;   \
      return __VA_ARGS__();                     \
    }                                           \
  }()
// #define HEADSIZE_SWITCH(HEADDIM, ...) \
// [&] { \
// if (HEADDIM == 64) { \
// constexpr static int HEAD_SIZE = 64; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 80) { \
// constexpr static int HEAD_SIZE = 80; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 96) { \
// constexpr static int HEAD_SIZE = 96; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 112) { \
// constexpr static int HEAD_SIZE = 112; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 128) { \
// constexpr static int HEAD_SIZE = 128; \
// return __VA_ARGS__(); \
// } else if (HEADDIM == 256) { \
// constexpr static int HEAD_SIZE = 256; \
// return __VA_ARGS__(); \
// } \
// else { \
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
// } \
// }()
// Maps a runtime head dimension onto a compile-time constant.
// Expands to an immediately-invoked lambda that, when HEADDIM equals 128,
// declares `HEAD_SIZE` as a static constexpr int of 128, invokes the callable
// given as the trailing argument(s) and returns its result. Any other value
// is reported through TORCH_CHECK with the message
// "Unsupported head size: <HEADDIM>".
// The argument is parenthesized in the comparison so that a compound
// expression such as `c ? 64 : 128` is compared as a whole rather than being
// split by the higher-precedence `==`.
#define HEADSIZE_SWITCH(HEADDIM, ...)                           \
  [&] {                                                         \
    if ((HEADDIM) == 128) {                                     \
      constexpr static int HEAD_SIZE = 128;                     \
      return __VA_ARGS__();                                     \
    } else {                                                    \
      TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);   \
    }                                                           \
  }()
// Chooses the compile-time constant `REUSE_KV_TIMES` from runtime values.
// Besides the `num_blocks` argument, the expansion reads `num_heads` and
// `num_kv_heads`, which must be visible at the point of use.
//   - 4 when num_heads is even, num_heads / num_kv_heads >= 4 and
//     num_blocks >= 1200;
//   - 2 when num_heads / num_kv_heads >= 2 and num_blocks >= 1200;
//   - 1 otherwise.
// The callable given as the trailing argument(s) is then invoked and its
// result returned.
// The `num_blocks` argument is parenthesized so that a compound expression
// such as `c ? a : b` is compared as a whole instead of capturing the
// preceding `&&` chain as its condition.
#define REUSEKV_SWITCH(num_blocks , ...)                                                    \
  [&] {                                                                                     \
    if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && (num_blocks) >= 1200){       \
      constexpr static int REUSE_KV_TIMES = 4;                                              \
      return __VA_ARGS__();                                                                 \
    } else if (num_heads / num_kv_heads >= 2 && (num_blocks) >= 1200){                      \
      constexpr static int REUSE_KV_TIMES = 2;                                              \
      return __VA_ARGS__();                                                                 \
    } else {                                                                                \
      constexpr static int REUSE_KV_TIMES = 1;                                              \
      return __VA_ARGS__();                                                                 \
    }                                                                                       \
  }()
// Chooses the compile-time constant `REUSE_KV_TIMES` from runtime values.
// Besides the `num_blocks` argument, the expansion reads `num_heads` and
// `num_kv_heads`, which must be visible at the point of use.
//   - 2 when num_heads > num_kv_heads and num_blocks >= 1200;
//   - 1 otherwise.
// The callable given as the trailing argument(s) is then invoked and its
// result returned.
// The `num_blocks` argument is parenthesized so that a compound expression
// such as `c ? a : b` is compared as a whole instead of capturing the
// preceding `&&` operand as its condition.
#define REUSEKV_SWITCH_V1(num_blocks , ...)                         \
  [&] {                                                             \
    if (num_heads > num_kv_heads && (num_blocks) >= 1200){          \
      constexpr static int REUSE_KV_TIMES = 2;                      \
      return __VA_ARGS__();                                         \
    } else {                                                        \
      constexpr static int REUSE_KV_TIMES = 1;                      \
      return __VA_ARGS__();                                         \
    }                                                               \
  }()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment