Unverified Commit 5f6d10c1 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)

parent 9b9a10d6
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror
\ No newline at end of file
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
namespace vllm { namespace vllm {
// Activation and gating kernel template. // Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( ...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
} }
} }
template<typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x; const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2; constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x; const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715; constexpr float KAPPA = 0.044715;
float x_cube = f * f * f; float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube); float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \ });
input.data_ptr<scalar_t>(), \
d); \ void silu_and_mul(torch::Tensor& out, // [..., d]
}); torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul( ...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
...@@ -108,54 +102,49 @@ __global__ void activation_kernel( ...@@ -108,54 +102,49 @@ __global__ void activation_kernel(
} }
} }
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ });
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm { namespace vllm {
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) { __device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x); const float x3 = (float)(x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
return ((T) 0.5) * x * (((T) 1.0) + t); (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -22,31 +23,31 @@ ...@@ -22,31 +23,31 @@
namespace vllm { namespace vllm {
// A vector type to store Q, K, V elements. // A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE> template <typename T, int VEC_SIZE>
struct Vec {}; struct Vec {};
// A vector type to store FP32 accumulators. // A vector type to store FP32 accumulators.
template<typename T> template <typename T>
struct FloatVec {}; struct FloatVec {};
// Template vector operations. // Template vector operations.
template<typename Acc, typename A, typename B> template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b); inline __device__ Acc mul(A a, B b);
template<typename T> template <typename T>
inline __device__ float sum(T v); inline __device__ float sum(T v);
template<typename T> template <typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b)); return sum(mul<T, T, T>(a, b));
} }
template<typename A, typename T> template <typename A, typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b)); return sum(mul<A, T, T>(a, b));
} }
template<typename T> template <typename T>
inline __device__ void zero(T& dst) { inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4; constexpr int WORDS = sizeof(T) / 4;
union { union {
...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { ...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw; dst = tmp.raw;
} }
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -27,15 +28,15 @@ ...@@ -27,15 +28,15 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
...@@ -45,7 +46,7 @@ ...@@ -45,7 +46,7 @@
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template<int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
...@@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, int PARTITION_SIZE = 0> // Zero means no partitioning.
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
...@@ -118,22 +116,29 @@ __device__ void paged_attention_kernel( ...@@ -118,22 +116,29 @@ __device__ void paged_attention_kernel(
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx =
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
...@@ -143,13 +148,14 @@ __device__ void paged_attention_kernel( ...@@ -143,13 +148,14 @@ __device__ void paged_attention_kernel(
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread
// fetch or compute 16 bytes at a time. // group fetch or compute 16 bytes at a time. For example, if the size of a
// For example, if the size of a thread group is 4 and the data type is half, // thread group is 4 and the data type is half, then the vector size is 16 /
// then the vector size is 16 / (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
...@@ -163,18 +169,21 @@ __device__ void paged_attention_kernel( ...@@ -163,18 +169,21 @@ __device__ void paged_attention_kernel(
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// th vectors of the query, and so on. // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -193,44 +202,50 @@ __device__ void paged_attention_kernel( ...@@ -193,44 +202,50 @@ __device__ void paged_attention_kernel(
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride const cache_t* k_ptr =
+ kv_head_idx * kv_head_stride k_cache + physical_block_number * kv_block_stride +
+ physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(k_vec_quant, kv_scale); k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
...@@ -285,13 +300,12 @@ __device__ void paged_attention_kernel( ...@@ -285,13 +300,12 @@ __device__ void paged_attention_kernel(
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions seq_idx * num_heads * max_num_partitions +
+ partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+ head_idx * max_num_partitions head_idx * max_num_partitions + partition_idx;
+ partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
...@@ -304,7 +318,8 @@ __device__ void paged_attention_kernel( ...@@ -304,7 +318,8 @@ __device__ void paged_attention_kernel(
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
...@@ -315,18 +330,21 @@ __device__ void paged_attention_kernel( ...@@ -315,18 +330,21 @@ __device__ void paged_attention_kernel(
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+ kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -337,14 +355,17 @@ __device__ void paged_attention_kernel( ...@@ -337,14 +355,17 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, kv_scale); v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the
// we should explicitly zero out the values since they may contain NaNs. // context, we should explicitly zero out the values since they may
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
...@@ -367,8 +388,8 @@ __device__ void paged_attention_kernel( ...@@ -367,8 +388,8 @@ __device__ void paged_attention_kernel(
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for logits // NOTE(woosuk): A barrier is required because the shared memory space for
// is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
...@@ -405,9 +426,9 @@ __device__ void paged_attention_kernel( ...@@ -405,9 +426,9 @@ __device__ void paged_attention_kernel(
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE scalar_t* out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+ partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -419,79 +440,75 @@ __device__ void paged_attention_kernel( ...@@ -419,79 +440,75 @@ __device__ void paged_attention_kernel(
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS,
typename cache_t, vllm::Fp8KVCacheDataType KV_DTYPE>
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size, block_size]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int num_kv_heads, // [num_heads]
const int* __restrict__ seq_lens, // [num_seqs] const float scale,
const int max_num_blocks_per_seq, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads] const int* __restrict__ seq_lens, // [num_seqs]
const int q_stride, const int max_num_blocks_per_seq,
const int kv_block_stride, const float* __restrict__ alibi_slopes, // [num_heads]
const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) { const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>( paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
/* exp_sums */ nullptr, /* max_logits */ nullptr, KV_DTYPE>(
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, int PARTITION_SIZE>
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
vllm::Fp8KVCacheDataType KV_DTYPE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // max_num_partitions, head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, PARTITION_SIZE>( const int q_stride, const int kv_block_stride, const int kv_head_stride,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, const float kv_scale) {
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
q_stride, kv_block_stride, kv_head_stride, kv_scale); KV_DTYPE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template< template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
typename scalar_t, int PARTITION_SIZE>
int HEAD_SIZE,
int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
const int* __restrict__ seq_lens, // [num_seqs] // max_num_partitions]
const int max_num_partitions) { const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
...@@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel(
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr =
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ head_idx * max_num_partitions * HEAD_SIZE; const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i]; out_ptr[i] = tmp_out_ptr[i];
} }
...@@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// Load max logits to shared memory. // Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions const float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions; seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX; float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i]; const float l = max_logits_ptr[i];
...@@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = VLLM_SHFL_SYNC(max_logit, 0); max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory. // Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); float* shared_exp_sums =
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+ head_idx * max_num_partitions; const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f; float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i]; float l = shared_max_logits[i];
...@@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel(
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out. // Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE const scalar_t* tmp_out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE; tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f; float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) { for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
} }
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ ((void*)vllm::paged_attention_v1_kernel< \
KV_DTYPE>), shared_mem_size); \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ shared_mem_size); \
KV_DTYPE><<<grid, block, shared_mem_size, stream>>>( \ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
out_ptr, \ NUM_THREADS, KV_DTYPE> \
query_ptr, \ <<<grid, block, shared_mem_size, stream>>>( \
key_cache_ptr, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
value_cache_ptr, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
num_kv_heads, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
scale, \ kv_scale);
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template <typename T, typename CACHE_T, int BLOCK_SIZE,
typename T, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& value_cache, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -632,9 +638,10 @@ void paged_attention_v1_launcher( ...@@ -632,9 +638,10 @@ void paged_attention_v1_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
...@@ -644,7 +651,8 @@ void paged_attention_v1_launcher( ...@@ -644,7 +651,8 @@ void paged_attention_v1_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
...@@ -683,19 +691,10 @@ void paged_attention_v1_launcher( ...@@ -683,19 +691,10 @@ void paged_attention_v1_launcher(
} }
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
query, \ seq_lens, max_seq_len, alibi_slopes, kv_scale);
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
...@@ -716,74 +715,45 @@ void paged_attention_v1_launcher( ...@@ -716,74 +715,45 @@ void paged_attention_v1_launcher(
} }
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
int num_kv_heads, // [num_heads] torch::Tensor&
float scale, value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] int num_kv_heads, // [num_heads]
torch::Tensor& seq_lens, // [num_seqs] float scale,
int block_size, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
int max_seq_len, torch::Tensor& seq_lens, // [num_seqs]
const c10::optional<torch::Tensor>& alibi_slopes, int block_size, int max_seq_len,
const std::string& kv_cache_dtype, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { const std::string& kv_cache_dtype, float kv_scale){
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
} CALL_V1_LAUNCHER_BLOCK_SIZE)}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
KV_DTYPE, PARTITION_SIZE> \ <<<grid, block, shared_mem_size, stream>>>( \
<<<grid, block, shared_mem_size, stream>>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
exp_sums_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
max_logits_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
tmp_out_ptr, \ kv_block_stride, kv_head_stride, kv_scale); \
query_ptr, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
key_cache_ptr, \ PARTITION_SIZE> \
value_cache_ptr, \ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
num_kv_heads, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
scale, \ max_num_partitions);
block_tables_ptr, \
seq_lens_ptr, \ template <typename T, typename CACHE_T, int BLOCK_SIZE,
max_num_blocks_per_seq, \ vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
alibi_slopes_ptr, \ int PARTITION_SIZE = 512>
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& query, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -796,9 +766,10 @@ void paged_attention_v2_launcher( ...@@ -796,9 +766,10 @@ void paged_attention_v2_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
...@@ -853,59 +824,50 @@ void paged_attention_v2_launcher( ...@@ -853,59 +824,50 @@ void paged_attention_v2_launcher(
} }
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>( \
out, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
exp_sums, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
max_logits, \ kv_scale);
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \
break; \ break; \
case 16: \ case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \ case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor&
torch::Tensor& query, // [num_seqs, num_heads, head_size] tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor&
int num_kv_heads, // [num_heads] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
float scale, torch::Tensor&
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& seq_lens, // [num_seqs] int num_kv_heads, // [num_heads]
int block_size, float scale,
int max_seq_len, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& seq_lens, // [num_seqs]
const std::string& kv_cache_dtype, int block_size, int max_seq_len,
float kv_scale) { const c10::optional<torch::Tensor>& alibi_slopes,
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) const std::string& kv_cache_dtype, float kv_scale) {
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
} }
#undef WARP_SIZE #undef WARP_SIZE
......
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -26,7 +27,7 @@ ...@@ -26,7 +27,7 @@
namespace vllm { namespace vllm {
// Q*K^T operation. // Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately). // Compute the parallel products for Q*K^T (treat vector lanes separately).
...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template<typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template<typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
}; };
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -28,8 +30,8 @@ ...@@ -28,8 +30,8 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#include <stdint.h> #include <stdint.h>
...@@ -50,37 +52,37 @@ struct bf16_8_t { ...@@ -50,37 +52,37 @@ struct bf16_8_t {
}; };
// BF16 vector types for Q, K, V. // BF16 vector types for Q, K, V.
template<> template <>
struct Vec<__nv_bfloat16, 1> { struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16; using Type = __nv_bfloat16;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 2> { struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 4> { struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t; using Type = bf16_4_t;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 8> { struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t; using Type = bf16_8_t;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<__nv_bfloat16> { struct FloatVec<__nv_bfloat16> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<__nv_bfloat162> { struct FloatVec<__nv_bfloat162> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<bf16_4_t> { struct FloatVec<bf16_4_t> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<bf16_8_t> { struct FloatVec<bf16_8_t> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
assert(false); assert(false);
#else #else
#ifndef USE_ROCM #ifndef USE_ROCM
return a + b; return a + b;
#else #else
return __hadd(a, b); return __hadd(a, b);
#endif #endif
#endif #endif
} }
...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { ...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { ...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c; bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { ...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_4_t c; bf16_4_t c;
...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c; bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { ...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_8_t c; bf16_8_t c;
...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a); float fa = __bfloat162float(a);
float fb = __bfloat162float(b); float fb = __bfloat162float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a); float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b); float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { ...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float4_ fc; Float4_ fc;
...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { ...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float8_ fc; Float8_ fc;
...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf ...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { ...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(__nv_bfloat16 v) { inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v); return __bfloat162float(v);
} }
template<> template <>
inline __device__ float sum(__nv_bfloat162 v) { inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v); float2 vf = bf1622float2(v);
return vf.x + vf.y; return vf.x + vf.y;
} }
template<> template <>
inline __device__ float sum(bf16_4_t v) { inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y); return sum(v.x) + sum(v.y);
} }
template<> template <>
inline __device__ float sum(bf16_8_t v) { inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
} }
...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { ...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
#endif #endif
} }
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -30,37 +32,37 @@ ...@@ -30,37 +32,37 @@
namespace vllm { namespace vllm {
// FP16 vector types for Q, K, V. // FP16 vector types for Q, K, V.
template<> template <>
struct Vec<uint16_t, 1> { struct Vec<uint16_t, 1> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint16_t, 2> { struct Vec<uint16_t, 2> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint16_t, 4> { struct Vec<uint16_t, 4> {
using Type = uint2; using Type = uint2;
}; };
template<> template <>
struct Vec<uint16_t, 8> { struct Vec<uint16_t, 8> {
using Type = uint4; using Type = uint4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<uint16_t> { struct FloatVec<uint16_t> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<uint32_t> { struct FloatVec<uint32_t> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<uint2> { struct FloatVec<uint2> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<uint4> { struct FloatVec<uint4> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { ...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b; return b;
#else #else
union { union {
uint32_t u32; uint32_t u32;
uint16_t u16[2]; uint16_t u16[2];
} tmp; } tmp;
tmp.u16[0] = a; tmp.u16[0] = a;
tmp.u16[1] = a; tmp.u16[1] = a;
...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { ...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif #endif
#else #else
tmp.u16[0] = float_to_half(f.x); tmp.u16[0] = float_to_half(f.x);
...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { ...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) { inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c; uint16_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { ...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) { inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c; uint32_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { ...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) { inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ uint2 mul(uint2 a, uint2 b) { inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c; uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { ...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) { inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint2 c; uint2 c;
...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { ...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint4 a, uint4 b) { inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c; uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { ...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) { inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint4 c; uint4 c;
...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { ...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(uint16_t a, uint16_t b) { inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a); float fa = half_to_float(a);
float fb = half_to_float(b); float fb = half_to_float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) { inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a); float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b); float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) { inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) { inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { ...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) { inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float4_ fc; Float4_ fc;
...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { ...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) { inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { ...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float8_ fc; Float8_ fc;
...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { ...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { ...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(uint16_t v) { inline __device__ float sum(uint16_t v) {
return half_to_float(v); return half_to_float(v);
} }
template<> template <>
inline __device__ float sum(uint32_t v) { inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v); float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y; return tmp.x + tmp.y;
} }
template<> template <>
inline __device__ float sum(uint2 v) { inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
return sum(c); return sum(c);
} }
template<> template <>
inline __device__ float sum(uint4 v) { inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
c = add(c, v.z); c = add(c, v.z);
...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { ...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { ...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -38,37 +40,35 @@ struct Float8_ { ...@@ -38,37 +40,35 @@ struct Float8_ {
}; };
// FP32 vector types for Q, K, V. // FP32 vector types for Q, K, V.
template<> template <>
struct Vec<float, 1> { struct Vec<float, 1> {
using Type = float; using Type = float;
}; };
template<> template <>
struct Vec<float, 2> { struct Vec<float, 2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct Vec<float, 4> { struct Vec<float, 4> {
using Type = float4; using Type = float4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<float> { struct FloatVec<float> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<float2> { struct FloatVec<float2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<float4> { struct FloatVec<float4> {
using Type = float4; using Type = float4;
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { ...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ float mul<float, float>(float a, float b) { inline __device__ float mul<float, float>(float a, float b) {
return a * b; return a * b;
} }
template<> template <>
inline __device__ float2 mul(float2 a, float2 b) { inline __device__ float2 mul(float2 a, float2 b) {
float2 c; float2 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { ...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float2 mul(float a, float2 b) { inline __device__ float2 mul(float a, float2 b) {
float2 c; float2 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { ...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float4 a, float4 b) { inline __device__ float4 mul(float4 a, float4 b) {
float4 c; float4 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { ...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float a, float4 b) { inline __device__ float4 mul(float a, float4 b) {
float4 c; float4 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { ...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { ...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(float v) { inline __device__ float sum(float v) {
return v; return v;
} }
template<> template <>
inline __device__ float sum(float2 v) { inline __device__ float sum(float2 v) {
return v.x + v.y; return v.x + v.y;
} }
template<> template <>
inline __device__ float sum(float4 v) { inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w; return v.x + v.y + v.z + v.w;
} }
template<> template <>
inline __device__ float sum(Float4_ v) { inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y; return v.x.x + v.x.y + v.y.x + v.y.y;
} }
template<> template <>
inline __device__ float sum(Float8_ v) { inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { ...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm
...@@ -4,38 +4,38 @@ ...@@ -4,38 +4,38 @@
#include <stdint.h> #include <stdint.h>
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cuda_fp8.h> #include <cuda_fp8.h>
#endif // USE_ROCM #endif // USE_ROCM
#endif // ENABLE_FP8 #endif // ENABLE_FP8
namespace vllm { namespace vllm {
enum class Fp8KVCacheDataType { enum class Fp8KVCacheDataType {
kAuto = 0, kAuto = 0,
kFp8E4M3 = 1, kFp8E4M3 = 1,
kFp8E5M2 = 2, kFp8E5M2 = 2,
}; };
// fp8 vector types for quantization of kv cache // fp8 vector types for quantization of kv cache
template<> template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
using Type = uint8_t; using Type = uint8_t;
}; };
template<> template <>
struct Vec<uint8_t, 2> { struct Vec<uint8_t, 2> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint8_t, 4> { struct Vec<uint8_t, 4> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint8_t, 8> { struct Vec<uint8_t, 8> {
using Type = uint2; using Type = uint2;
}; };
} // namespace vllm } // namespace vllm
...@@ -5,36 +5,24 @@ ...@@ -5,36 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const torch::Tensor& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping);
const torch::Tensor& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, const float kv_scale);
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache,
torch::Tensor& value, torch::Tensor& value_cache,
torch::Tensor& key_cache, torch::Tensor& slot_mapping,
torch::Tensor& value_cache, const std::string& kv_cache_dtype);
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& src_cache,
const float scale,
const std::string& kv_cache_dtype);
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
#include "dispatch_utils.h" #include "dispatch_utils.h"
#ifdef USE_ROCM #ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh" #include "quantization/fp8/amd/quant_utils.cuh"
#else #else
#include "quantization/fp8/nvidia/quant_utils.cuh" #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
...@@ -18,20 +18,17 @@ ...@@ -18,20 +18,17 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping) {
torch::Tensor& dst,
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost; memcpy_type = cudaMemcpyDeviceToHost;
...@@ -41,16 +38,17 @@ void swap_blocks( ...@@ -41,16 +38,17 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination"); TORCH_CHECK(false, "Invalid device combination");
} }
// NOTE(youkaichao): keep in mind that `block_mapping` should be // NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu // a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization. // synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char *src_ptr = static_cast<char*>(src.data_ptr()); char* src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0); const int64_t num_blocks = block_mapping.size(0);
...@@ -59,29 +57,25 @@ void swap_blocks( ...@@ -59,29 +57,25 @@ void swap_blocks(
int64_t dst_block_number = block_mapping[i][1].item<int64_t>(); int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
namespace vllm { namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template<typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
int64_t* value_cache_ptrs, const int64_t* __restrict__ block_mapping,
const int64_t* __restrict__ block_mapping, const int numel_per_block) {
const int numel_per_block) {
const int layer_idx = blockIdx.x; const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
...@@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel( ...@@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel(
} }
} }
} // namespace vllm } // namespace vllm
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping) {
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
...@@ -118,8 +111,10 @@ void copy_blocks( ...@@ -118,8 +111,10 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// block_mapping is a 2D tensor with shape (num_pairs, 2). // block_mapping is a 2D tensor with shape (num_pairs, 2).
...@@ -127,10 +122,12 @@ void copy_blocks( ...@@ -127,10 +122,12 @@ void copy_blocks(
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
...@@ -139,31 +136,28 @@ void copy_blocks( ...@@ -139,31 +136,28 @@ void copy_blocks(
const at::cuda::OptionalCUDAGuard device_guard(cache_device); const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block); }));
}));
} }
namespace vllm { namespace vllm {
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
const int64_t* __restrict__ slot_mapping, // [num_tokens] cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
const int key_stride, // block_size]
const int value_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int head_size, const int block_size, const int x,
const int block_size, const float kv_scale) {
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
...@@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel( ...@@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); key_cache[tgt_key_idx] =
value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
template<typename scalar_t> template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] // head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
const int block_stride, // head_size]
const int key_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int value_stride, const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int num_heads, const int head_size, const int block_size) {
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
...@@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride const int64_t tgt_value_idx = block_idx * block_stride +
+ block_offset * num_heads * head_size block_offset * num_heads * head_size +
+ head_idx * head_size head_idx * head_size + head_offset;
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx]; k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx]; v_cache[tgt_value_idx] = value[src_value_idx];
} }
} }
} // namespace vllm } // namespace vllm
// KV_T is the stored data type of kv-cache. // KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors. // CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
reinterpret_cast<KV_T*>(key.data_ptr()), \ <<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(value.data_ptr()), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
key_stride, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
value_stride, \ num_heads, head_size, block_size, x, kv_scale);
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor&
const std::string& kv_cache_dtype, value_cache, // [num_blocks, num_heads, head_size, block_size]
const float kv_scale) torch::Tensor& slot_mapping, // [num_tokens]
{ const std::string& kv_cache_dtype, const float kv_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
...@@ -283,17 +270,17 @@ void reshape_and_cache( ...@@ -283,17 +270,17 @@ void reshape_and_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
} }
void reshape_and_cache_flash( void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) const std::string& kv_cache_dtype) {
{
// FIXME: only support auto datatype, does not support fp8 // FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") { if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
...@@ -313,62 +300,47 @@ void reshape_and_cache_flash( ...@@ -313,62 +300,47 @@ void reshape_and_cache_flash(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), key.scalar_type(), "reshape_and_cache_flash", [&] {
"reshape_and_cache_flash", vllm::reshape_and_cache_flash_kernel<scalar_t>
[&] { <<<grid, block, 0, stream>>>(
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
k_cache.data_ptr<scalar_t>(), value_stride, num_heads, head_size, block_size);
v_cache.data_ptr<scalar_t>(), });
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
} }
namespace vllm { namespace vllm {
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache,
Tout* __restrict__ dst_cache, const float kv_scale,
const float kv_scale, const int64_t block_stride) {
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
kv_scale, \
block_stride);
// Only for testing. // Only for testing.
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache, const float kv_scale, const std::string& kv_cache_dtype) {
torch::Tensor& src_cache,
const float kv_scale,
const std::string& kv_cache_dtype)
{
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device); at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0); int64_t num_blocks = src_cache.size(0);
...@@ -398,13 +370,15 @@ void convert_fp8( ...@@ -398,13 +370,15 @@ void convert_fp8(
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) { } else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) { } else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
......
#include "cpu_types.hpp" #include "cpu_types.hpp"
namespace { namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated> bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t *__restrict__ output) { scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, ...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
} }
} }
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp()); return x / (ones + (zeros - x).exp());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { ...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { ...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er()); return x * w2 * (ones + (x * w1).er());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { ...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh()); return x * w2 * (ones + inner.tanh());
} }
}; // namespace }; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) activation_kernel<scalar_t, silu_act, true>(
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
} }
void gelu_and_mul(torch::Tensor &out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) activation_kernel<scalar_t, gelu_act, true>(
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
} }
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] ...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
}); });
} }
void gelu_new(torch::Tensor &out, torch::Tensor &input) { void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { ...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
}); });
} }
void gelu_fast(torch::Tensor &out, torch::Tensor &input) { void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
...@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType { ...@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
...@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> { ...@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
...@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> { ...@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
...@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> { ...@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
#endif #endif
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) { const int capacity) {
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, ...@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int seq_len) { const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, ...@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
} }
template <typename T> template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
const int size) { const int size) {
T max = max_data[0]; T max = max_data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -132,9 +137,9 @@ struct reduceQKBlockKernel { ...@@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q, FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t *__restrict__ k_block, const scalar_t* __restrict__ k_block,
float *__restrict__ logits, float scale, float* __restrict__ logits, float scale,
const int token_num) { const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
...@@ -196,8 +201,8 @@ struct reduceQKBlockKernel { ...@@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t> int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t &&acc) { acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM); static_assert(BLOCK_SIZE == ELEM_NUM);
...@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, ...@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
}); });
} }
}; // namespace }; // namespace
// Paged attention v1 // Paged attention v1
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) { const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
...@@ -243,32 +248,31 @@ struct paged_attention_v1_impl { ...@@ -243,32 +248,31 @@ struct paged_attention_v1_impl {
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE; float* __restrict__ thread_block_logits =
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
...@@ -282,8 +286,7 @@ struct paged_attention_v1_impl { ...@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, seq_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
...@@ -293,14 +296,14 @@ struct paged_attention_v1_impl { ...@@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition; head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -311,7 +314,7 @@ struct paged_attention_v1_impl { ...@@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -340,16 +343,16 @@ struct paged_attention_v1_impl { ...@@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
template <typename T, int BLOCK_SIZE> template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher( ...@@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default: default:
TORCH_CHECK(false, "Unsupported head size: ", head_size); TORCH_CHECK(false, "Unsupported head size: ", head_size);
break; break;
} }
} }
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \ CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor &seq_lens, int block_size, int block_size, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor> &alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) {
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] { [&] {
...@@ -434,23 +436,24 @@ namespace { ...@@ -434,23 +436,24 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] // max_num_partitions, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x]
// head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) { const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
...@@ -468,8 +471,7 @@ struct paged_attention_v2_impl { ...@@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
...@@ -477,15 +479,14 @@ struct paged_attention_v2_impl { ...@@ -477,15 +479,14 @@ struct paged_attention_v2_impl {
const int token_num = const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
...@@ -493,10 +494,10 @@ struct paged_attention_v2_impl { ...@@ -493,10 +494,10 @@ struct paged_attention_v2_impl {
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
...@@ -510,13 +511,13 @@ struct paged_attention_v2_impl { ...@@ -510,13 +511,13 @@ struct paged_attention_v2_impl {
logits, token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto &&[max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr; scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) { if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions + auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
...@@ -538,13 +539,13 @@ struct paged_attention_v2_impl { ...@@ -538,13 +539,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition; output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -555,7 +556,7 @@ struct paged_attention_v2_impl { ...@@ -555,7 +556,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -587,8 +588,7 @@ struct paged_attention_v2_impl { ...@@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
...@@ -603,11 +603,11 @@ struct paged_attention_v2_impl { ...@@ -603,11 +603,11 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1) #pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
...@@ -616,17 +616,16 @@ struct paged_attention_v2_impl { ...@@ -616,17 +616,16 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float *__restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions; head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out = const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output = scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
...@@ -645,21 +644,21 @@ struct paged_attention_v2_impl { ...@@ -645,21 +644,21 @@ struct paged_attention_v2_impl {
} }
}; };
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512> template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher( void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher( ...@@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1); int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default: default:
TORCH_CHECK(false, "Unsupported head size: ", head_size); TORCH_CHECK(false, "Unsupported head size: ", head_size);
break; break;
} }
} }
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_seq_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \ CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables, float scale, torch::Tensor& block_tables,
torch::Tensor &seq_lens, int block_size, torch::Tensor& seq_lens, int block_size,
int max_seq_len, int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] { [&] {
......
...@@ -5,25 +5,26 @@ ...@@ -5,25 +5,26 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor> &value_caches, const torch::Tensor& mapping_pairs,
const torch::Tensor& mapping_pairs, const int element_num_per_block,
const int element_num_per_block, const int layer_num) { const int layer_num) {
const size_t pair_num = mapping_pairs.size(0); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>(); int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>(); element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset; scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset; source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset; target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
...@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( ...@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(
template <typename scalar_t> template <typename scalar_t>
void reshape_and_cache_cpu_impl( void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens, const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) { const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size; const int block_elem_num = num_heads * head_size * block_size;
...@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( ...@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx = int src_value_head_idx =
token_idx * value_stride + head_idx * head_size; token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx; const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx; const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size; const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size; const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache + scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache + scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
...@@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl( ...@@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl(
} }
} }
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches, void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size(); unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
...@@ -100,10 +101,10 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches, ...@@ -100,10 +101,10 @@ void copy_blocks(std::vector<torch::Tensor> &key_caches,
}); });
} }
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor &slot_mapping, torch::Tensor& slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
...@@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, ...@@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}); });
} }
void swap_blocks(torch::Tensor &src, torch::Tensor &dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor&block_mapping) { const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
} }
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out, void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t *__restrict__ input, const scalar_t* __restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon, const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) { const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
...@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, ...@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
} }
template <typename scalar_t> template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input, void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t *__restrict__ residual, scalar_t* __restrict__ residual,
const scalar_t *__restrict__ weight, const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens, const float epsilon, const int num_tokens,
const int hidden_size) { const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
...@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, ...@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} }
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl) CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size); hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl) CPU_KERNEL_GUARD_OUT(rms_norm_impl)
}); });
} }
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
...@@ -4,16 +4,16 @@ ...@@ -4,16 +4,16 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
...@@ -26,7 +26,7 @@ void rotary_embedding_impl( ...@@ -26,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
const int head_idx = i; const int head_idx = i;
...@@ -94,16 +94,16 @@ void rotary_embedding_impl( ...@@ -94,16 +94,16 @@ void rotary_embedding_impl(
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
...@@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl( ...@@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query; scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl( ...@@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head; scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl( ...@@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl(
} }
} }
} }
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
......
...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul, "Activation function used in GeGLU with `none` approximation.");
"Activation function used in SwiGLU."); ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
ops.def( "Activation function used in GeGLU with `tanh` approximation.");
"gelu_and_mul", ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
&gelu_and_mul, ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm", "Apply Root Mean Square (RMS) Normalization to the input tensor.");
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm", "In-place fused Add and RMS Normalization");
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks", "Swap in (out) the cache blocks from src to dst");
&swap_blocks, cache_ops.def("copy_blocks", &copy_blocks,
"Swap in (out) the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"copy_blocks", "Reshape the key and value tensors and cache them");
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
} }
#pragma once #pragma once
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif #endif
...@@ -29,7 +30,8 @@ ...@@ -29,7 +30,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else #else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif #endif
...@@ -41,4 +43,3 @@ ...@@ -41,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment