Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e053941
Commit
9e053941
authored
Mar 19, 2025
by
zhuwenwen
Browse files
skip fp8 kernel and _rocm_C extension
parent
f850f22a
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
992 additions
and
985 deletions
+992
-985
CMakeLists.txt
CMakeLists.txt
+6
-4
cmake/utils.cmake
cmake/utils.cmake
+1
-1
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+655
-655
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+1
-1
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+237
-235
csrc/ops.h
csrc/ops.h
+15
-15
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+3
-1
csrc/quantization/fused_kernels/quant_conversions.cuh
csrc/quantization/fused_kernels/quant_conversions.cuh
+1
-1
csrc/quantization/gptq/compat.cuh
csrc/quantization/gptq/compat.cuh
+12
-12
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+30
-30
setup.py
setup.py
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+25
-25
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+4
-3
No files found.
CMakeLists.txt
View file @
9e053941
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
...
@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
#
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
#
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
#
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/advance_step.cu"
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
...
@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP")
if(VLLM_GPU_LANG STREQUAL "HIP")
#
#
# _rocm_C extension
# _rocm_C extension
...
@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
...
@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3
USE_SABI 3
WITH_SOABI)
WITH_SOABI)
endif()
endif()
]]
# For CUDA we also build and ship some external projects.
# For CUDA we also build and ship some external projects.
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
include
(
cmake/external_projects/flashmla.cmake
)
include
(
cmake/external_projects/flashmla.cmake
)
include
(
cmake/external_projects/vllm_flash_attn.cmake
)
include
(
cmake/external_projects/vllm_flash_attn.cmake
)
endif
()
endif
()
\ No newline at end of file
cmake/utils.cmake
View file @
9e053941
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
...
@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list
(
APPEND GPU_FLAGS
list
(
APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DUSE_ROCM"
"-DENABLE_FP8"
#
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc"
)
"-fno-gpu-rdc"
)
...
...
csrc/attention/attention_kernels.cuh
View file @
9e053941
...
@@ -17,660 +17,660 @@
...
@@ -17,660 +17,660 @@
* limitations under the License.
* limitations under the License.
*/
*/
#include <torch/all.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define WARP_SIZE 32
#define WARP_SIZE 32
#else
#else
#define WARP_SIZE warpSize
#define WARP_SIZE warpSize
#endif
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace
vllm
{
namespace
vllm
{
// Block-wide sum reduction used by the attention softmax.
//
// Every calling thread contributes `sum`. Per-warp partial sums are staged in
// `red_smem` (indexed by warp number, so it is written at indices
// [0, NUM_WARPS)), then combined by a second shuffle reduction. The returned
// value is the one held by lane 0 of the caller's warp after that second
// reduction. The function contains a __syncthreads(), so every thread of the
// block has to reach it.
// NOTE(review): the second stage looks like it assumes NUM_WARPS is a power of
// two and not larger than WARP_SIZE -- confirm against the launch configs.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Split the thread index into its warp number and its lane in that warp.
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Stage 1: XOR-shuffle reduction across the lanes of each warp.
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  // Lane 0 of each warp publishes that warp's partial sum.
  if (lane_id == 0) {
    red_smem[warp_id] = sum;
  }

  // Barrier: the partial sums must be in shared memory before they are read.
  __syncthreads();

  // Stage 2: the first NUM_WARPS lanes each pick up one partial sum ...
  if (lane_id < NUM_WARPS) {
    sum = red_smem[lane_id];
  }

  // ... and fold them together with another XOR-shuffle reduction.
#pragma unroll
  for (int offset = NUM_WARPS / 2; offset > 0; offset >>= 1) {
    sum += VLLM_SHFL_XOR_SYNC(sum, offset);
  }

  // Hand lane 0's total to every lane.
  return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
seq_len
=
seq_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
// No work to do. Terminate the thread block.
// No work to do. Terminate the thread block.
return
;
return
;
}
}
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
NUM_THREAD_GROUPS
=
constexpr
int
NUM_THREAD_GROUPS
=
NUM_THREADS
/
THREAD_GROUP_SIZE
;
// Note: This assumes THREAD_GROUP_SIZE
NUM_THREADS
/
THREAD_GROUP_SIZE
;
// Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
// divides NUM_THREADS
assert
(
NUM_THREADS
%
THREAD_GROUP_SIZE
==
0
);
assert
(
NUM_THREADS
%
THREAD_GROUP_SIZE
==
0
);
constexpr
int
NUM_TOKENS_PER_THREAD_GROUP
=
constexpr
int
NUM_TOKENS_PER_THREAD_GROUP
=
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
warp_idx
=
thread_idx
/
WARP_SIZE
;
const
int
warp_idx
=
thread_idx
/
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
// A vector type to store a part of a key or a query.
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
// (4 * sizeof(half)) == 2.
constexpr
int
VEC_SIZE
=
MAX
(
16
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
constexpr
int
VEC_SIZE
=
MAX
(
16
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
constexpr
int
NUM_ELEMS_PER_THREAD
=
HEAD_SIZE
/
THREAD_GROUP_SIZE
;
constexpr
int
NUM_ELEMS_PER_THREAD
=
HEAD_SIZE
/
THREAD_GROUP_SIZE
;
constexpr
int
NUM_VECS_PER_THREAD
=
NUM_ELEMS_PER_THREAD
/
VEC_SIZE
;
constexpr
int
NUM_VECS_PER_THREAD
=
NUM_ELEMS_PER_THREAD
/
VEC_SIZE
;
const
int
thread_group_idx
=
thread_idx
/
THREAD_GROUP_SIZE
;
const
int
thread_group_idx
=
thread_idx
/
THREAD_GROUP_SIZE
;
const
int
thread_group_offset
=
thread_idx
%
THREAD_GROUP_SIZE
;
const
int
thread_group_offset
=
thread_idx
%
THREAD_GROUP_SIZE
;
// Load the query to registers.
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
thread_group_offset
][
i
]
=
q_vecs
[
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// memory wall right before we use q_vecs
// Memory planning.
// Memory planning.
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Workspace for reduction.
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
// Each thread group fetches x elements from the key at a time.
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
=
-
FLT_MAX
;
float
qk_max
=
-
FLT_MAX
;
// Iterate over the key blocks.
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// blocksparse specific vars
// blocksparse specific vars
int
bs_block_offset
;
int
bs_block_offset
;
int
q_bs_block_id
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
// sliding on q heads
bs_block_offset
=
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
else
// sliding on kv heads
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
(
-
blocksparse_head_sliding_step
)
+
1
;
1
;
}
}
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// For blocksparse attention: skip computation on blocks that are not
// attended
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
bool
is_remote
=
const
bool
is_remote
=
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
const
bool
is_local
=
const
bool
is_local
=
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
if
(
!
is_remote
&&
!
is_local
)
{
if
(
!
is_remote
&&
!
is_local
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// NOTE(linxihui): assign very large number to skipped tokens to
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
// skipped.
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
}
}
}
}
continue
;
continue
;
}
}
}
}
const
int64_t
physical_block_number
=
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
}
else
{
// Vector conversion from Quant_vec to K_vec.
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
*
k_scale
);
k_vec_quant
,
*
k_scale
);
}
}
}
}
// Compute dot product.
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
}
}
}
}
}
}
// Perform reduction across the threads in the same warp to get the
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
}
if
(
lane
==
0
)
{
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
qk_max
;
red_smem
[
warp_idx
]
=
qk_max
;
}
}
__syncthreads
();
__syncthreads
();
// TODO(woosuk): Refactor this part.
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
// Get the max qk value for the sequence.
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
}
// Broadcast the max qk value to all threads.
// Broadcast the max qk value to all threads.
qk_max
=
VLLM_SHFL_SYNC
(
qk_max
,
0
);
qk_max
=
VLLM_SHFL_SYNC
(
qk_max
,
0
);
// Get the sum of the exp values.
// Get the sum of the exp values.
float
exp_sum
=
0.
f
;
float
exp_sum
=
0.
f
;
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[
i
]
-
qk_max
);
float
val
=
__expf
(
logits
[
i
]
-
qk_max
);
logits
[
i
]
=
val
;
logits
[
i
]
=
val
;
exp_sum
+=
val
;
exp_sum
+=
val
;
}
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[
i
]
*=
inv_sum
;
logits
[
i
]
*=
inv_sum
;
}
}
__syncthreads
();
__syncthreads
();
// If partitioning is enabled, store the max logit and exp_sum.
// If partitioning is enabled, store the max logit and exp_sum.
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
;
*
max_logits_ptr
=
qk_max
;
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
*
exp_sums_ptr
=
exp_sum
;
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
L_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
L_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_quant_vec
=
typename
Vec
<
cache_t
,
V_VEC_SIZE
>::
Type
;
using
V_quant_vec
=
typename
Vec
<
cache_t
,
V_VEC_SIZE
>::
Type
;
using
Float_L_vec
=
typename
FloatVec
<
L_vec
>::
Type
;
using
Float_L_vec
=
typename
FloatVec
<
L_vec
>::
Type
;
constexpr
int
NUM_V_VECS_PER_ROW
=
BLOCK_SIZE
/
V_VEC_SIZE
;
constexpr
int
NUM_V_VECS_PER_ROW
=
BLOCK_SIZE
/
V_VEC_SIZE
;
constexpr
int
NUM_ROWS_PER_ITER
=
WARP_SIZE
/
NUM_V_VECS_PER_ROW
;
constexpr
int
NUM_ROWS_PER_ITER
=
WARP_SIZE
/
NUM_V_VECS_PER_ROW
;
constexpr
int
NUM_ROWS_PER_THREAD
=
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
float
accs
[
NUM_ROWS_PER_THREAD
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
accs
[
i
]
=
0.
f
;
}
}
scalar_t
zero_value
;
scalar_t
zero_value
;
zero
(
zero_value
);
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// For blocksparse attention: skip computation on blocks that are not
// attended
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
continue
;
}
}
}
}
const
int64_t
physical_block_number
=
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
L_vec
logits_vec
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
token_idx
-
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
token_idx
-
start_token_idx
));
start_token_idx
));
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
kv_head_idx
*
kv_head_stride
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
V_vec
v_vec
;
V_vec
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
}
else
{
V_quant_vec
v_quant_vec
=
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
*
v_scale
);
*
v_scale
);
}
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
}
}
}
}
// Perform reduction within each warp.
// Perform reduction within each warp.
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
float
acc
=
accs
[
i
];
#pragma unroll
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
}
accs
[
i
]
=
acc
;
accs
[
i
]
=
acc
;
}
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
// logits is reused for the output.
__syncthreads
();
__syncthreads
();
// Perform reduction across warps.
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
i
];
dst
[
row_idx
]
=
accs
[
i
];
}
}
}
}
}
}
__syncthreads
();
__syncthreads
();
// Lower warps update the output.
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
i
]
+=
src
[
row_idx
];
accs
[
i
]
+=
src
[
row_idx
];
}
}
}
}
}
}
__syncthreads
();
__syncthreads
();
}
}
// Write the final output.
// Write the final output.
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
i
]);
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
i
]);
}
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
(
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
blocksparse_head_sliding_step
);
}
}
// Grid: (num_heads, num_seqs, max_num_partitions).
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
(
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
// max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
// max_num_partitions, head_size]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
int
max_num_partitions
)
{
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
const
int
num_heads
=
gridDim
.
x
;
// head_size/x, block_size, x]
const
int
head_idx
=
blockIdx
.
x
;
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
const
int
seq_idx
=
blockIdx
.
y
;
// head_size, block_size]
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
num_kv_heads
,
// [num_heads]
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
const
float
scale
,
if
(
num_partitions
==
1
)
{
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
// No need to reduce. Only copy tmp_out to out.
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
scalar_t
*
out_ptr
=
const
int
max_num_blocks_per_seq
,
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
scalar_t
*
tmp_out_ptr
=
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
blockDim
.
x
)
{
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
out_ptr
[
i
]
=
tmp_out_ptr
[
i
];
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
}
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
// Terminate the thread block.
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
return
;
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
}
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
blocksparse_head_sliding_step
);
const
int
warp_idx
=
threadIdx
.
x
/
WARP_SIZE
;
}
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Grid: (num_heads, num_seqs).
// Size: 2 * num_partitions.
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
extern
__shared__
char
shared_mem
[];
int
PARTITION_SIZE
>
// Workspace for reduction.
__global__
void
paged_attention_v2_reduce_kernel
(
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// Load max logits to shared memory.
// max_num_partitions]
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
const
float
*
max_logits_ptr
=
max_logits
+
// max_num_partitions]
seq_idx
*
num_heads
*
max_num_partitions
+
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
head_idx
*
max_num_partitions
;
// max_num_partitions, head_size]
float
max_logit
=
-
FLT_MAX
;
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
const
int
max_num_partitions
)
{
const
float
l
=
max_logits_ptr
[
i
];
const
int
num_heads
=
gridDim
.
x
;
shared_max_logits
[
i
]
=
l
;
const
int
head_idx
=
blockIdx
.
x
;
max_logit
=
fmaxf
(
max_logit
,
l
);
const
int
seq_idx
=
blockIdx
.
y
;
}
const
int
seq_len
=
seq_lens
[
seq_idx
];
__syncthreads
();
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
// Get the global max logit.
// No need to reduce. Only copy tmp_out to out.
// Reduce within the warp.
scalar_t
*
out_ptr
=
#pragma unroll
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
const
scalar_t
*
tmp_out_ptr
=
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
}
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
if
(
lane
==
0
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
blockDim
.
x
)
{
red_smem
[
warp_idx
]
=
max_logit
;
out_ptr
[
i
]
=
tmp_out_ptr
[
i
];
}
}
__syncthreads
();
// Terminate the thread block.
// Reduce across warps.
return
;
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
}
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
const
int
warp_idx
=
threadIdx
.
x
/
WARP_SIZE
;
}
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
// Size: 2 * num_partitions.
extern
__shared__
char
shared_mem
[];
// Load rescaled exp sums to shared memory.
// Workspace for reduction.
float
*
shared_exp_sums
=
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
const
float
*
exp_sums_ptr
=
exp_sums
+
// Load max logits to shared memory.
seq_idx
*
num_heads
*
max_num_partitions
+
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
head_idx
*
max_num_partitions
;
const
float
*
max_logits_ptr
=
max_logits
+
float
global_exp_sum
=
0.0
f
;
seq_idx
*
num_heads
*
max_num_partitions
+
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
head_idx
*
max_num_partitions
;
float
l
=
shared_max_logits
[
i
];
float
max_logit
=
-
FLT_MAX
;
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
global_exp_sum
+=
rescaled_exp_sum
;
const
float
l
=
max_logits_ptr
[
i
];
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
shared_max_logits
[
i
]
=
l
;
}
max_logit
=
fmaxf
(
max_logit
,
l
);
__syncthreads
();
}
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
__syncthreads
();
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
// Get the global max logit.
// Aggregate tmp_out to out.
// Reduce within the warp.
const
scalar_t
*
tmp_out_ptr
=
#pragma unroll
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
scalar_t
*
out_ptr
=
}
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
if
(
lane
==
0
)
{
#pragma unroll
red_smem
[
warp_idx
]
=
max_logit
;
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
}
float
acc
=
0.0
f
;
__syncthreads
();
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
// Reduce across warps.
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
inv_global_exp_sum
;
#pragma unroll
}
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
from_float
(
out_ptr
[
i
],
acc
);
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
}
}
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
}
// namespace vllm
// Load rescaled exp sums to shared memory.
#undef WARP_SIZE
float
*
shared_exp_sums
=
#undef MAX
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
#undef MIN
const
float
*
exp_sums_ptr
=
exp_sums
+
#undef DIVIDE_ROUND_UP
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
global_exp_sum
=
0.0
f
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
float
l
=
shared_max_logits
[
i
];
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
global_exp_sum
+=
rescaled_exp_sum
;
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
}
__syncthreads
();
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
// Aggregate tmp_out to out.
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
inv_global_exp_sum
;
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
// namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/cache_kernels.cu
View file @
9e053941
...
@@ -728,4 +728,4 @@ void gather_cache(
...
@@ -728,4 +728,4 @@ void gather_cache(
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type width: "
,
dtype_bits
);
TORCH_CHECK
(
false
,
"Unsupported data type width: "
,
dtype_bits
);
}
}
}
}
\ No newline at end of file
csrc/layernorm_quant_kernels.cu
View file @
9e053941
...
@@ -5,238 +5,240 @@
...
@@ -5,238 +5,240 @@
* Currently, only static fp8 quantization is supported.
* Currently, only static fp8 quantization is supported.
*/
*/
#include "type_convert.cuh"
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#ifndef USE_ROCM
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
#endif
#include <torch/cuda.h>
#include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#ifndef USE_ROCM
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#else
#ifndef USE_ROCM
#include <hipcub/hipcub.hpp>
#include <cub/cub.cuh>
#endif
#else
#include <hipcub/hipcub.hpp>
namespace
vllm
{
#endif
// TODO(woosuk): Further optimize this kernel.
namespace
vllm
{
template
<
typename
scalar_t
,
typename
fp8_type
>
__global__
void
rms_norm_static_fp8_quant_kernel
(
// TODO(woosuk): Further optimize this kernel.
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
template
<
typename
scalar_t
,
typename
fp8_type
>
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
__global__
void
rms_norm_static_fp8_quant_kernel
(
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
float
*
__restrict__
scale
,
// [1]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
__shared__
float
s_variance
;
const
float
*
__restrict__
scale
,
// [1]
float
variance
=
0.0
f
;
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
variance
=
0.0
f
;
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
}
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
}
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
if
(
threadIdx
.
x
==
0
)
{
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
();
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
// invert scale to avoid division
__syncthreads
();
float
const
scale_inv
=
1.0
f
/
*
scale
;
// invert scale to avoid division
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
const
scale_inv
=
1.0
f
/
*
scale
;
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
}
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
/* Function specialization in the case of FP16/BF16 tensors.
}
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
/* Function specialization in the case of FP16/BF16 tensors.
memory latency bottleneck. */
Additional optimizations we can make in this case are
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
packed and vectorized operations, which help with the
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
memory latency bottleneck. */
fused_add_rms_norm_static_fp8_quant_kernel
(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
fused_add_rms_norm_static_fp8_quant_kernel
(
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
float
*
__restrict__
scale
,
// [1]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
// Sanity checks on our vector struct and type-punned pointer arithmetic
const
float
*
__restrict__
scale
,
// [1]
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
const
int
vec_hidden_size
=
hidden_size
/
width
;
/* These and the argument pointers are all declared `restrict` as they are
__shared__
float
s_variance
;
not aliased in practice. Argument pointers should not be dereferenced
float
variance
=
0.0
f
;
in this kernel as that would be undefined behavior */
/* These and the argument pointers are all declared `restrict` as they are
auto
*
__restrict__
input_v
=
not aliased in practice. Argument pointers should not be dereferenced
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
in this kernel as that would be undefined behavior */
auto
*
__restrict__
residual_v
=
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
weight_v
=
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
temp
+=
residual_v
[
id
];
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
variance
+=
temp
.
sum_squares
();
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
residual_v
[
id
]
=
temp
;
temp
+=
residual_v
[
id
];
}
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
}
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
if
(
threadIdx
.
x
==
0
)
{
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
();
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
// invert scale to avoid division
__syncthreads
();
float
const
scale_inv
=
1.0
f
/
*
scale
;
// invert scale to avoid division
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
float
const
scale_inv
=
1.0
f
/
*
scale
;
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
temp
*=
s_variance
;
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
temp
*=
weight_v
[
idx
];
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
#pragma unroll
temp
*=
s_variance
;
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
temp
*=
weight_v
[
idx
];
out
[
id
*
width
+
i
]
=
#pragma unroll
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
}
out
[
id
*
width
+
i
]
=
}
scaled_fp8_conversion
<
true
,
fp8_type
>
(
float
(
temp
.
data
[
i
]),
scale_inv
);
}
}
}
/* Generic fused_add_rms_norm_kernel
}
The width field is not used here but necessary for other specializations.
*/
/* Generic fused_add_rms_norm_kernel
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
The width field is not used here but necessary for other specializations.
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
*/
fused_add_rms_norm_static_fp8_quant_kernel
(
template
<
typename
scalar_t
,
int
width
,
typename
fp8_type
>
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
fused_add_rms_norm_static_fp8_quant_kernel
(
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
fp8_type
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
float
*
__restrict__
scale
,
// [1]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
__shared__
float
s_variance
;
const
float
*
__restrict__
scale
,
// [1]
float
variance
=
0.0
f
;
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
variance
=
0.0
f
;
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
z
;
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
float
x
=
(
float
)
z
;
}
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
}
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
if
(
threadIdx
.
x
==
0
)
{
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
();
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
// invert scale to avoid division
__syncthreads
();
float
const
scale_inv
=
1.0
f
/
*
scale
;
// invert scale to avoid division
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
const
scale_inv
=
1.0
f
/
*
scale
;
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
float
const
out_norm
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
}
scaled_fp8_conversion
<
true
,
fp8_type
>
(
out_norm
,
scale_inv
);
}
}
// namespace vllm
}
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
}
// namespace vllm
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
scale
,
// [1]
torch
::
Tensor
&
input
,
// [..., hidden_size]
double
epsilon
)
{
torch
::
Tensor
&
weight
,
// [hidden_size]
int
hidden_size
=
input
.
size
(
-
1
);
torch
::
Tensor
&
scale
,
// [1]
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
dim3
grid
(
num_tokens
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
VLLM_DISPATCH_FLOATING_TYPES
(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FP8_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
input
.
scalar_type
(),
"rms_norm_kernel_scalar_type"
,
[
&
]
{
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
VLLM_DISPATCH_FP8_TYPES
(
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
scalar_type
(),
"rms_norm_kernel_fp8_type"
,
[
&
]
{
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
vllm
::
rms_norm_static_fp8_quant_kernel
<
scalar_t
,
fp8_t
>
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
epsilon
,
num_tokens
,
hidden_size
);
out
.
data_ptr
<
fp8_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
});
weight
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
});
epsilon
,
num_tokens
,
hidden_size
);
}
});
});
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
}
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FP8_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
VLLM_DISPATCH_FP8_TYPES( \
width, fp8_t> \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
<<<grid, block, 0, stream>>>( \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
width, fp8_t> \
residual.data_ptr<scalar_t>(), \
<<<grid, block, 0, stream>>>( \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \
}); \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
});
epsilon, num_tokens, hidden_size); \
void
fused_add_rms_norm_static_fp8_quant
(
}); \
torch
::
Tensor
&
out
,
// [..., hidden_size],
});
torch
::
Tensor
&
input
,
// [..., hidden_size]
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
out
,
// [..., hidden_size],
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
scale
,
// [1]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
double
epsilon
)
{
torch
::
Tensor
&
weight
,
// [hidden_size]
int
hidden_size
=
input
.
size
(
-
1
);
torch
::
Tensor
&
scale
,
// [1]
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
double
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
dim3
grid
(
num_tokens
);
for increased block occupancy on CUs and better latency
/* This kernel is memory-latency bound in many scenarios.
hiding on global mem ops. */
When num_tokens is large, a smaller block size allows
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
for increased block occupancy on CUs and better latency
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
hiding on global mem ops. */
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
with packed + vectorized ops.
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
Max optimization is achieved with a width-8 vector of FP16/BF16s
/*If the tensor types are FP16/BF16, try to use the optimized kernel
since we can load at most 128 bits at once in a global memory op.
with packed + vectorized ops.
However, this requires each tensor's data to be aligned to 16
Max optimization is achieved with a width-8 vector of FP16/BF16s
bytes.
since we can load at most 128 bits at once in a global memory op.
*/
However, this requires each tensor's data to be aligned to 16
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
bytes.
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
*/
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
bool
ptrs_are_aligned
=
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
bool
ptrs_are_aligned
=
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
}
else
{
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
}
else
{
}
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
\ No newline at end of file
csrc/ops.h
View file @
9e053941
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
...
@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
torch
::
Tensor
&
weight
,
double
epsilon
);
void
rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
//
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& weight, torch::Tensor& scale,
double
epsilon
);
//
double epsilon);
void
fused_add_rms_norm_static_fp8_quant
(
torch
::
Tensor
&
out
,
//
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch
::
Tensor
&
input
,
//
torch::Tensor& input,
torch
::
Tensor
&
residual
,
//
torch::Tensor& residual,
torch
::
Tensor
&
weight
,
//
torch::Tensor& weight,
torch
::
Tensor
&
scale
,
double
epsilon
);
//
torch::Tensor& scale, double epsilon);
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
void
rms_norm_dynamic_per_token_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
...
@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
const
&
scale
);
//
torch::Tensor const& scale);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
//
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
&
scale
);
//
torch::Tensor& scale);
void
dynamic_per_token_scaled_fp8_quant
(
//
void dynamic_per_token_scaled_fp8_quant(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scale
,
//
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std
::
optional
<
torch
::
Tensor
>
const
&
scale_ub
);
//
std::optional<torch::Tensor> const& scale_ub);
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
void
selective_scan_fwd
(
const
torch
::
Tensor
&
u
,
const
torch
::
Tensor
&
delta
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
9e053941
#pragma once
#pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h>
#include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
...
@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
// namespace fp8
}
// namespace fp8
#endif // USE_ROCM
#endif // USE_ROCM
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
csrc/quantization/fused_kernels/quant_conversions.cuh
View file @
9e053941
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh"
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
//
#include "quantization/fp8/common.cuh"
namespace
vllm
{
namespace
vllm
{
...
...
csrc/quantization/gptq/compat.cuh
View file @
9e053941
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
...
@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
//
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half
*
address
,
half
val
)
{
//
__device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half
(
address
,
val
);
//
atomicAdd_half(address, val);
}
//
}
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
//
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__
__forceinline__
void
atomicAdd
(
half2
*
address
,
half2
val
)
{
//
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2
(
address
,
val
);
//
atomicAdd_half2(address, val);
}
//
}
#endif
//
#endif
#endif
//
#endif
#endif
//
#endif
}
// namespace gptq
}
// namespace gptq
}
// namespace vllm
}
// namespace vllm
...
...
csrc/torch_bindings.cpp
View file @
9e053941
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
//
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
//
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
//
"Tensor scale, float epsilon) -> "
"()"
);
//
"()");
ops
.
impl
(
"rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&
rms_norm_static_fp8_quant
);
//
&rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
// In-place fused Add and RMS Normalization.
ops
.
def
(
//
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
//
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
//
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"
);
//
"Tensor scale, float epsilon) -> ()");
ops
.
impl
(
"fused_add_rms_norm_static_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&
fused_add_rms_norm_static_fp8_quant
);
//
&fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
// Fused Layernorm + Quant kernels
ops
.
def
(
ops
.
def
(
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Compute FP8 quantized tensor for given scaling factor.
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
//
ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
//
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"
);
//
"()");
ops
.
impl
(
"static_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
static_scaled_fp8_quant
);
//
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
//
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
//
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
//
"-> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_fp8_quant
);
//
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops
.
def
(
//
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
//
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
//
"Tensor! scale, Tensor? scale_ub) -> "
"()"
);
//
"()");
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
//
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&
dynamic_per_token_scaled_fp8_quant
);
//
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
ops
.
def
(
...
@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
...
@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
}
}
#endif
#endif
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
\ No newline at end of file
setup.py
View file @
9e053941
...
@@ -643,8 +643,8 @@ ext_modules = []
...
@@ -643,8 +643,8 @@ ext_modules = []
if
_is_cuda
()
or
_is_hip
():
if
_is_cuda
()
or
_is_hip
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._moe_C"
))
if
_is_hip
():
#
if _is_hip():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._rocm_C"
))
#
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if
_is_cuda
():
if
_is_cuda
():
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa2_C"
))
...
...
vllm/_custom_ops.py
View file @
9e053941
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
...
@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size
,
blocksparse_head_sliding_step
)
blocksparse_block_size
,
blocksparse_head_sliding_step
)
def
paged_attention_rocm
(
#
def paged_attention_rocm(
out
:
torch
.
Tensor
,
#
out: torch.Tensor,
exp_sum
:
torch
.
Tensor
,
#
exp_sum: torch.Tensor,
max_logits
:
torch
.
Tensor
,
#
max_logits: torch.Tensor,
tmp_out
:
torch
.
Tensor
,
#
tmp_out: torch.Tensor,
query
:
torch
.
Tensor
,
#
query: torch.Tensor,
key_cache
:
torch
.
Tensor
,
#
key_cache: torch.Tensor,
value_cache
:
torch
.
Tensor
,
#
value_cache: torch.Tensor,
num_kv_heads
:
int
,
#
num_kv_heads: int,
scale
:
float
,
#
scale: float,
block_tables
:
torch
.
Tensor
,
#
block_tables: torch.Tensor,
seq_lens
:
torch
.
Tensor
,
#
seq_lens: torch.Tensor,
block_size
:
int
,
#
block_size: int,
max_seq_len
:
int
,
#
max_seq_len: int,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
#
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype
:
str
,
#
kv_cache_dtype: str,
k_scale
:
torch
.
Tensor
,
#
k_scale: torch.Tensor,
v_scale
:
torch
.
Tensor
,
#
v_scale: torch.Tensor,
)
->
None
:
#
) -> None:
torch
.
ops
.
_rocm_C
.
paged_attention
(
out
,
exp_sum
,
max_logits
,
tmp_out
,
query
,
#
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache
,
value_cache
,
num_kv_heads
,
#
key_cache, value_cache, num_kv_heads,
scale
,
block_tables
,
seq_lens
,
#
scale, block_tables, seq_lens,
block_size
,
max_seq_len
,
alibi_slopes
,
#
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype
,
k_scale
,
v_scale
)
#
kv_cache_dtype, k_scale, v_scale)
# pos encoding ops
# pos encoding ops
...
@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
...
@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
\ No newline at end of file
vllm/attention/backends/rocm_flash_attn.py
View file @
9e053941
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
use_custom
=
_use_rocm_custom_paged_attention
(
# use_custom = _use_rocm_custom_paged_attention(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
# decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta
.
max_decode_seq_len
)
# decode_meta.max_decode_seq_len)
use_custom
=
False
if
use_custom
:
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
!=
AttentionType
.
ENCODER_DECODER
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment