Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
import argparse
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
# Default value for the --models command-line option (keys of WEIGHT_SHAPES).
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
# Default value for the --batch-sizes option; each entry is used as size_m.
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
# Values of act_order swept by main(), unless narrowed by --limit-act-order.
ACT_ORDER_OPTS = [False, True]
# Values of is_k_full swept by main(), unless narrowed by --limit-k-full.
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    """Time one quantized-matmul configuration and append the measurements.

    Random half-precision inputs are created on the CUDA device, quantized
    with the Marlin, Marlin-24 and GPTQ helpers, and then the following
    statements are timed with ``torch.utils.benchmark.Timer``:

    * ``torch.matmul`` on the dequantized Marlin reference weight,
    * ``gptq_marlin_gemm``,
    * ``gptq_marlin_24_gemm`` (only when ``num_bits`` and ``group_size`` are
      in the Marlin-24 supported sets),
    * ``gptq_marlin_repack``.

    Args:
        results: List that receives one measurement per timed statement.
        model: Model name; used only in the printed/recorded sub-label.
        act_order: Passed to the quantization helpers; when true the GPTQ
            weights are additionally sorted to obtain the repack indices.
        is_k_full: Forwarded to ``gptq_marlin_gemm``.
        num_bits: Quantization bit width.
        group_size: Quantization group size.
        size_m: Rows of the activation matrix.
        size_k: Shared inner dimension.
        size_n: Columns of the weight matrix.
    """
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

    # GPTQ quant; only the quantized weight and group index are needed here.
    (_, q_w, _, g_idx,
     _) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing. Only the resulting sort indices are
    # consumed below (by the repack benchmark).
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (_, _, repack_sort_indices) = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)

    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)

    # Namespace in which the timed statements are evaluated. Named
    # ``timer_globals`` so the builtin ``globals`` is not shadowed.
    timer_globals = {
        # Gen params
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "a_tmp": a_tmp,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
    }

    min_run_time = 1

    def record(stmt, description):
        # Time ``stmt`` and store the measurement under ``description``.
        results.append(
            benchmark.Timer(
                stmt=stmt,
                globals=timer_globals,
                label=label,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=min_run_time))

    # Warmup pytorch
    for _ in range(5):
        torch.matmul(a, marlin_w_ref)

    record("torch.matmul(a, marlin_w_ref)", "pytorch_gemm")

    record(
        "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
        "gptq_marlin_gemm",
    )

    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
        record(
            "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
            "gptq_marlin_24_gemm",
        )

    record(
        "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
        "gptq_marlin_repack",
    )
def main(args):
    """Run ``bench_run`` over every selected configuration and print a table.

    The sweep covers, in this nesting order: model, weight shape, act_order,
    is_k_full, num_bits, group_size and batch size. Each ``args.limit_*`` list
    restricts the corresponding dimension; an empty list means "no
    restriction".

    Args:
        args: Parsed command-line namespace providing ``models``,
            ``batch_sizes`` and the ``limit_k``, ``limit_n``,
            ``limit_act_order``, ``limit_k_full``, ``limit_num_bits`` and
            ``limit_group_size`` lists.
    """
    import itertools

    def selected(limit, value):
        # An empty limit list places no restriction on the value.
        return not limit or value in limit

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    # The option filters do not depend on the model or layer, so apply them
    # once up front instead of inside the innermost loops.
    act_orders = [
        opt for opt in ACT_ORDER_OPTS if selected(args.limit_act_order, opt)
    ]
    k_fulls = [opt for opt in K_FULL_OPTS if selected(args.limit_k_full, opt)]
    bit_widths = [
        bits for bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
        if selected(args.limit_num_bits, bits)
    ]
    group_sizes = [
        size for size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
        if selected(args.limit_group_size, size)
    ]

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if not selected(args.limit_k, size_k):
                continue

            if not selected(args.limit_n, size_n):
                continue

            # product() varies the right-most iterable fastest, which is the
            # same visiting order as the equivalent nested for-loops.
            for act_order, is_k_full, num_bits, group_size in itertools.product(
                    act_orders, k_fulls, bit_widths, group_sizes):
                # For act_order, the group_size must be less than
                # size_k
                if act_order and group_size in (size_k, -1):
                    continue

                for size_m in args.batch_sizes:
                    bench_run(results, model, act_order, is_k_full, num_bits,
                              group_size, size_m, size_k, size_n)

    compare = benchmark.Compare(results)
    compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")

    # Which entries of WEIGHT_SHAPES to benchmark.
    parser.add_argument("--models",
                        nargs="+",
                        type=str,
                        default=DEFAULT_MODELS,
                        choices=WEIGHT_SHAPES.keys())

    # Batch sizes to sweep.
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)

    # Optional integer filters; each defaults to an empty list, which main()
    # treats as "no restriction".
    for limit_flag in (
            "--limit-k",
            "--limit-n",
            "--limit-group-size",
            "--limit-num-bits",
            "--limit-act-order",
            "--limit-k-full",
    ):
        parser.add_argument(limit_flag, nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
...@@ -170,7 +170,7 @@ if __name__ == '__main__': ...@@ -170,7 +170,7 @@ if __name__ == '__main__':
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
...@@ -183,13 +183,11 @@ if __name__ == '__main__': ...@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
choices=["auto", "fp8"], choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto", default="auto",
help= help="Data type for kv cache storage. If 'auto', will use model "
'Data type for kv cache storage. If "auto", will use model data type. ' "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
'FP8_E5M2 (without scaling) is only supported on cuda version greater ' "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",
......
# Weight shapes to benchmark, keyed by a model identifier.
# Each value is a list of two-element entries; the Marlin benchmark's main()
# reads entry[0] as size_k and entry[1] as size_n.
# NOTE(review): the "/TP<n>" suffix in the keys presumably denotes the
# tensor-parallel degree the shapes were derived for -- confirm.
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}
...@@ -4,7 +4,7 @@ PORT=8000 ...@@ -4,7 +4,7 @@ PORT=8000
MODEL=$1 MODEL=$1
TOKENS=$2 TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \ docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \ -v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \ ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \ --model-id $MODEL \
......
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
# A very long prompt (about 15k tokens in total): the same short question
# repeated 1000 times, separated by single spaces.
LONG_PROMPT = ' '.join(
    "You are an expert in large language models, aren't you?"
    for _ in range(1000))
def main(args):
    """Profile ``llm.generate`` and report the time spent in block hashing.

    An ``LLM`` is built with prefix caching enabled, warmed up with three
    generations of ``LONG_PROMPT``, and then three further generations are
    run under ``cProfile``. The cumulative time of the profiled function
    whose name contains ``hash_of_block`` is printed together with its share
    of the total profiled time.

    Args:
        args: Parsed command-line namespace providing ``model``,
            ``tensor_parallel_size``, ``use_v2_block_manager`` and
            ``output_len``.
    """
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for _ in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for _ in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # analyze the runtime of hashing function
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    # Keys of stats.stats are (filename, line number, function name) tuples;
    # entry [3] of each value is that function's cumulative time.
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
    percentage = (total_time / stats.total_tt) * 100
    # The space after the comma was missing, so the two halves of the
    # message ran together in the output.
    print(f"Hashing took {total_time:.2f} seconds, "
          f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser(
        # The trailing space keeps the two adjacent literals from
        # concatenating into "inautomatic".
        description='Benchmark the performance of hashing function in '
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    # NOTE(review): main() as visible here passes enable_prefix_caching=True
    # unconditionally, so this flag appears to have no effect -- confirm
    # whether it should be wired through or removed.
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
...@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif() endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS list(REMOVE_ITEM GPU_FLAGS
...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
# "-DENABLE_FP8_E4M3" # "-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc" "-fno-gpu-rdc"
......
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
namespace vllm { namespace vllm {
// Activation and gating kernel template. // Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( ...@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
} }
} }
template<typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x; const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2; constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x; const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715; constexpr float KAPPA = 0.044715;
float x_cube = f * f * f; float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube); float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \ });
input.data_ptr<scalar_t>(), \
d); \ void silu_and_mul(torch::Tensor& out, // [..., d]
}); torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul( ...@@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
...@@ -108,54 +102,49 @@ __global__ void activation_kernel( ...@@ -108,54 +102,49 @@ __global__ void activation_kernel(
} }
} }
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ });
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm { namespace vllm {
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) { __device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x); const float x3 = (float)(x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
return ((T) 0.5) * x * (((T) 1.0) + t); (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -22,31 +23,31 @@ ...@@ -22,31 +23,31 @@
namespace vllm { namespace vllm {
// A vector type to store Q, K, V elements. // A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE> template <typename T, int VEC_SIZE>
struct Vec {}; struct Vec {};
// A vector type to store FP32 accumulators. // A vector type to store FP32 accumulators.
template<typename T> template <typename T>
struct FloatVec {}; struct FloatVec {};
// Template vector operations. // Template vector operations.
template<typename Acc, typename A, typename B> template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b); inline __device__ Acc mul(A a, B b);
template<typename T> template <typename T>
inline __device__ float sum(T v); inline __device__ float sum(T v);
template<typename T> template <typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b)); return sum(mul<T, T, T>(a, b));
} }
template<typename A, typename T> template <typename A, typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b)); return sum(mul<A, T, T>(a, b));
} }
template<typename T> template <typename T>
inline __device__ void zero(T& dst) { inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4; constexpr int WORDS = sizeof(T) / 4;
union { union {
...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { ...@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw; dst = tmp.raw;
} }
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -19,27 +20,23 @@ ...@@ -19,27 +20,23 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#if defined(ENABLE_FP8_E5M2)
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16; #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
...@@ -49,7 +46,7 @@ ...@@ -49,7 +46,7 @@
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template<int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
...@@ -86,31 +83,31 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -86,31 +83,31 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, bool IS_BLOCK_SPARSE,
int HEAD_SIZE, int PARTITION_SIZE = 0> // Zero means no partitioning.
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
...@@ -122,22 +119,29 @@ __device__ void paged_attention_kernel( ...@@ -122,22 +119,29 @@ __device__ void paged_attention_kernel(
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; const int start_block_idx =
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
...@@ -147,19 +151,18 @@ __device__ void paged_attention_kernel( ...@@ -147,19 +151,18 @@ __device__ void paged_attention_kernel(
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group // The vector size is configured in such a way that the threads in a thread
// fetch or compute 16 bytes at a time. // group fetch or compute 16 bytes at a time. For example, if the size of a
// For example, if the size of a thread group is 4 and the data type is half, // thread group is 4 and the data type is half, then the vector size is 16 /
// then the vector size is 16 / (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
...@@ -169,18 +172,21 @@ __device__ void paged_attention_kernel( ...@@ -169,18 +172,21 @@ __device__ void paged_attention_kernel(
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// th vectors of the query, and so on. // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
...@@ -199,51 +205,94 @@ __device__ void paged_attention_kernel( ...@@ -199,51 +205,94 @@ __device__ void paged_attention_kernel(
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 // blocksparse specific vars
// because int32 can lead to overflow when this variable is multiplied by large numbers int bs_block_offset;
// (e.g., kv_block_stride). int q_bs_block_id;
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group // For example, if the the thread group size is 4, then the first thread in
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride const cache_t* k_ptr =
+ kv_head_idx * kv_head_stride k_cache + physical_block_number * kv_block_stride +
+ physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2) if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = *reinterpret_cast<const K_vec*>(
// Vector conversion from Quant_vec to K_vec. k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#elif defined(ENABLE_FP8_E4M3)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
#else
assert(false);
#endif
} else { } else {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
...@@ -298,13 +347,12 @@ __device__ void paged_attention_kernel( ...@@ -298,13 +347,12 @@ __device__ void paged_attention_kernel(
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions seq_idx * num_heads * max_num_partitions +
+ partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+ head_idx * max_num_partitions head_idx * max_num_partitions + partition_idx;
+ partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
...@@ -312,14 +360,13 @@ __device__ void paged_attention_kernel( ...@@ -312,14 +360,13 @@ __device__ void paged_attention_kernel(
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type; using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type; using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
...@@ -330,44 +377,51 @@ __device__ void paged_attention_kernel( ...@@ -330,44 +377,51 @@ __device__ void paged_attention_kernel(
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 block_idx += NUM_WARPS) {
// because int32 can lead to overflow when this variable is multiplied by large numbers // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// (e.g., kv_block_stride). // int64 because int32 can lead to overflow when this variable is multiplied
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+ kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec; V_vec v_vec;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2) if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#elif defined(ENABLE_FP8_E4M3)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
#else
assert(false);
#endif
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context, // NOTE(woosuk): When v_vec contains the tokens that are out of the
// we should explicitly zero out the values since they may contain NaNs. // context, we should explicitly zero out the values since they may
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
...@@ -390,8 +444,8 @@ __device__ void paged_attention_kernel( ...@@ -390,8 +444,8 @@ __device__ void paged_attention_kernel(
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for logits // NOTE(woosuk): A barrier is required because the shared memory space for
// is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
...@@ -428,9 +482,9 @@ __device__ void paged_attention_kernel( ...@@ -428,9 +482,9 @@ __device__ void paged_attention_kernel(
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE scalar_t* out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+ partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
...@@ -442,79 +496,84 @@ __device__ void paged_attention_kernel( ...@@ -442,79 +496,84 @@ __device__ void paged_attention_kernel(
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, bool IS_BLOCK_SPARSE>
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] // head_size/x, block_size, x]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size, block_size]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int num_kv_heads, // [num_heads]
const int* __restrict__ seq_lens, // [num_seqs] const float scale,
const int max_num_blocks_per_seq, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads] const int* __restrict__ seq_lens, // [num_seqs]
const int q_stride, const int max_num_blocks_per_seq,
const int kv_block_stride, const float* __restrict__ alibi_slopes, // [num_heads]
const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale) { const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>( const int blocksparse_vert_stride, const int blocksparse_block_size,
/* exp_sums */ nullptr, /* max_logits */ nullptr, const int blocksparse_head_sliding_step) {
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template< template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
typename scalar_t, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
typename cache_t, bool IS_BLOCK_SPARSE,
int HEAD_SIZE, int PARTITION_SIZE>
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] // max_num_partitions]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] // max_num_partitions, head_size]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int num_kv_heads, // [num_heads] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const float scale, // head_size/x, block_size, x]
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int* __restrict__ seq_lens, // [num_seqs] // head_size, block_size]
const int max_num_blocks_per_seq, const int num_kv_heads, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads] const float scale,
const int q_stride, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int kv_block_stride, const int* __restrict__ seq_lens, // [num_seqs]
const int kv_head_stride, const int max_num_blocks_per_seq,
const float kv_scale) { const float* __restrict__ alibi_slopes, // [num_heads]
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>( const int q_stride, const int kv_block_stride, const int kv_head_stride,
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, const int blocksparse_vert_stride, const int blocksparse_block_size,
q_stride, kv_block_stride, kv_head_stride, kv_scale); const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template< template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
typename scalar_t, int PARTITION_SIZE>
int HEAD_SIZE,
int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ exp_sums, // [num_seqs, num_heads,
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] // max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const float* __restrict__ max_logits, // [num_seqs, num_heads,
const int* __restrict__ seq_lens, // [num_seqs] // max_num_partitions]
const int max_num_partitions) { const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
...@@ -522,9 +581,11 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -522,9 +581,11 @@ __global__ void paged_attention_v2_reduce_kernel(
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; scalar_t* out_ptr =
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ head_idx * max_num_partitions * HEAD_SIZE; const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i]; out_ptr[i] = tmp_out_ptr[i];
} }
...@@ -543,8 +604,9 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -543,8 +604,9 @@ __global__ void paged_attention_v2_reduce_kernel(
// Load max logits to shared memory. // Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions const float* max_logits_ptr = max_logits +
+ head_idx * max_num_partitions; seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX; float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i]; const float l = max_logits_ptr[i];
...@@ -573,9 +635,11 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -573,9 +635,11 @@ __global__ void paged_attention_v2_reduce_kernel(
max_logit = VLLM_SHFL_SYNC(max_logit, 0); max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory. // Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); float* shared_exp_sums =
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+ head_idx * max_num_partitions; const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f; float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i]; float l = shared_max_logits[i];
...@@ -588,61 +652,52 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -588,61 +652,52 @@ __global__ void paged_attention_v2_reduce_kernel(
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out. // Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE const scalar_t* tmp_out_ptr =
+ head_idx * max_num_partitions * HEAD_SIZE; tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f; float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) { for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
} }
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
IS_FP8_KV_CACHE>), shared_mem_size); \ BLOCK_SIZE, NUM_THREADS, \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ KV_DTYPE, IS_BLOCK_SPARSE>), \
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \ shared_mem_size); \
out_ptr, \ vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
query_ptr, \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
key_cache_ptr, \ <<<grid, block, shared_mem_size, stream>>>( \
value_cache_ptr, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
scale, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
block_tables_ptr, \ kv_scale, tp_rank, blocksparse_local_blocks, \
seq_lens_ptr, \ blocksparse_vert_stride, blocksparse_block_size, \
max_num_blocks_per_seq, \ blocksparse_head_sliding_step);
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template< template <typename T, typename CACHE_T, int BLOCK_SIZE,
typename T, vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
typename CACHE_T, int NUM_THREADS = 128>
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& value_cache, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
int num_kv_heads, const int tp_rank, const int blocksparse_local_blocks,
float scale, const int blocksparse_vert_stride, const int blocksparse_block_size,
torch::Tensor& block_tables, const int blocksparse_head_sliding_step) {
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -655,9 +710,10 @@ void paged_attention_v1_launcher( ...@@ -655,9 +710,10 @@ void paged_attention_v1_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
...@@ -667,7 +723,8 @@ void paged_attention_v1_launcher( ...@@ -667,7 +723,8 @@ void paged_attention_v1_launcher(
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float); int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
...@@ -697,6 +754,9 @@ void paged_attention_v1_launcher( ...@@ -697,6 +754,9 @@ void paged_attention_v1_launcher(
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V1(128); LAUNCH_PAGED_ATTENTION_V1(128);
break; break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256: case 256:
LAUNCH_PAGED_ATTENTION_V1(256); LAUNCH_PAGED_ATTENTION_V1(256);
break; break;
...@@ -706,128 +766,93 @@ void paged_attention_v1_launcher( ...@@ -706,128 +766,93 @@ void paged_attention_v1_launcher(
} }
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
out, \ IS_BLOCK_SPARSE>( \
query, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
key_cache, \ seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
value_cache, \ blocksparse_local_blocks, blocksparse_vert_stride, \
num_kv_heads, \ blocksparse_block_size, blocksparse_head_sliding_step);
scale, \
block_tables, \ #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
seq_lens, \ switch (is_block_sparse) { \
max_seq_len, \ case true: \
alibi_slopes, \ CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
kv_scale); break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \ break; \
case 16: \ case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \ case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
int num_kv_heads, // [num_heads] torch::Tensor&
float scale, value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] int num_kv_heads, // [num_heads]
torch::Tensor& seq_lens, // [num_seqs] float scale,
int block_size, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
int max_seq_len, torch::Tensor& seq_lens, // [num_seqs]
const c10::optional<torch::Tensor>& alibi_slopes, int block_size, int max_seq_len,
const std::string& kv_cache_dtype, const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) { const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
if (kv_cache_dtype == "auto") { const int blocksparse_local_blocks, const int blocksparse_vert_stride,
if (query.dtype() == at::ScalarType::Float) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); const bool is_block_sparse = (blocksparse_vert_stride > 1);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
} else if (query.dtype() == at::ScalarType::BFloat16) { CALL_V1_LAUNCHER_BLOCK_SIZE)
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
IS_FP8_KV_CACHE, PARTITION_SIZE> \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
<<<grid, block, shared_mem_size, stream>>>( \ PARTITION_SIZE> \
exp_sums_ptr, \ <<<grid, block, shared_mem_size, stream>>>( \
max_logits_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
tmp_out_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
query_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
key_cache_ptr, \ kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
value_cache_ptr, \ blocksparse_local_blocks, blocksparse_vert_stride, \
num_kv_heads, \ blocksparse_block_size, blocksparse_head_sliding_step); \
scale, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
block_tables_ptr, \ PARTITION_SIZE> \
seq_lens_ptr, \ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
max_num_blocks_per_seq, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
alibi_slopes_ptr, \ max_num_partitions);
q_stride, \
kv_block_stride, \ template <typename T, typename CACHE_T, int BLOCK_SIZE,
kv_head_stride, \ vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
kv_scale); \ int NUM_THREADS = 128, int PARTITION_SIZE = 512>
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
torch::Tensor& query, const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
torch::Tensor& key_cache, const int tp_rank, const int blocksparse_local_blocks,
torch::Tensor& value_cache, const int blocksparse_vert_stride, const int blocksparse_block_size,
int num_kv_heads, const int blocksparse_head_sliding_step) {
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -840,9 +865,10 @@ void paged_attention_v2_launcher( ...@@ -840,9 +865,10 @@ void paged_attention_v2_launcher(
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ? const float* alibi_slopes_ptr =
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) alibi_slopes
: nullptr; ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
...@@ -888,6 +914,9 @@ void paged_attention_v2_launcher( ...@@ -888,6 +914,9 @@ void paged_attention_v2_launcher(
case 128: case 128:
LAUNCH_PAGED_ATTENTION_V2(128); LAUNCH_PAGED_ATTENTION_V2(128);
break; break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256: case 256:
LAUNCH_PAGED_ATTENTION_V2(256); LAUNCH_PAGED_ATTENTION_V2(256);
break; break;
...@@ -897,84 +926,68 @@ void paged_attention_v2_launcher( ...@@ -897,84 +926,68 @@ void paged_attention_v2_launcher(
} }
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \ paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
out, \ IS_BLOCK_SPARSE>( \
exp_sums, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
max_logits, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
tmp_out, \ kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
query, \ blocksparse_block_size, blocksparse_head_sliding_step);
key_cache, \
value_cache, \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
num_kv_heads, \ switch (is_block_sparse) { \
scale, \ case true: \
block_tables, \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
seq_lens, \ break; \
max_seq_len, \ case false: \
alibi_slopes, \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
kv_scale); break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \ break; \
case 16: \ case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \ case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor&
torch::Tensor& query, // [num_seqs, num_heads, head_size] tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor&
int num_kv_heads, // [num_heads] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
float scale, torch::Tensor&
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& seq_lens, // [num_seqs] int num_kv_heads, // [num_heads]
int block_size, float scale,
int max_seq_len, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& seq_lens, // [num_seqs]
const std::string& kv_cache_dtype, int block_size, int max_seq_len,
float kv_scale) { const c10::optional<torch::Tensor>& alibi_slopes,
if (kv_cache_dtype == "auto") { const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
if (query.dtype() == at::ScalarType::Float) { const int blocksparse_local_blocks, const int blocksparse_vert_stride,
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
} else if (query.dtype() == at::ScalarType::Half) { const bool is_block_sparse = (blocksparse_vert_stride > 1);
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
} else if (query.dtype() == at::ScalarType::BFloat16) { CALL_V2_LAUNCHER_BLOCK_SIZE)
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
} }
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -26,7 +27,7 @@ ...@@ -26,7 +27,7 @@
namespace vllm { namespace vllm {
// Q*K^T operation. // Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately). // Compute the parallel products for Q*K^T (treat vector lanes separately).
...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { ...@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template<typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template<typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
}; };
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -28,8 +30,8 @@ ...@@ -28,8 +30,8 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#include <stdint.h> #include <stdint.h>
...@@ -50,37 +52,37 @@ struct bf16_8_t { ...@@ -50,37 +52,37 @@ struct bf16_8_t {
}; };
// BF16 vector types for Q, K, V. // BF16 vector types for Q, K, V.
template<> template <>
struct Vec<__nv_bfloat16, 1> { struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16; using Type = __nv_bfloat16;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 2> { struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 4> { struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t; using Type = bf16_4_t;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 8> { struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t; using Type = bf16_8_t;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<__nv_bfloat16> { struct FloatVec<__nv_bfloat16> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<__nv_bfloat162> { struct FloatVec<__nv_bfloat162> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<bf16_4_t> { struct FloatVec<bf16_4_t> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<bf16_8_t> { struct FloatVec<bf16_8_t> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
// assert(false); // assert(false);
// #else // #else
#ifndef USE_ROCM #ifndef USE_ROCM
return a + b; return a + b;
#else #else
return __hadd(a, b); return __hadd(a, b);
#endif #endif
// #endif // #endif
} }
...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { ...@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { ...@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
// #endif // #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { ...@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
// #endif // #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c; bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { ...@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_4_t c; bf16_4_t c;
...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c; bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { ...@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_8_t c; bf16_8_t c;
...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a); float fa = __bfloat162float(a);
float fb = __bfloat162float(b); float fb = __bfloat162float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a); float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b); float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { ...@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float4_ fc; Float4_ fc;
...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { ...@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { ...@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float8_ fc; Float8_ fc;
...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { ...@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
// #else // #else
...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf ...@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
// #endif // #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(false); // assert(false);
// #else // #else
...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { ...@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(__nv_bfloat16 v) { inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v); return __bfloat162float(v);
} }
template<> template <>
inline __device__ float sum(__nv_bfloat162 v) { inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v); float2 vf = bf1622float2(v);
return vf.x + vf.y; return vf.x + vf.y;
} }
template<> template <>
inline __device__ float sum(bf16_4_t v) { inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y); return sum(v.x) + sum(v.y);
} }
template<> template <>
inline __device__ float sum(bf16_8_t v) { inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
} }
...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { ...@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
// #endif // #endif
} }
} // namespace vllm } // namespace vllm
\ No newline at end of file
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -30,37 +32,37 @@ ...@@ -30,37 +32,37 @@
namespace vllm { namespace vllm {
// FP16 vector types for Q, K, V. // FP16 vector types for Q, K, V.
template<> template <>
struct Vec<uint16_t, 1> { struct Vec<uint16_t, 1> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint16_t, 2> { struct Vec<uint16_t, 2> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint16_t, 4> { struct Vec<uint16_t, 4> {
using Type = uint2; using Type = uint2;
}; };
template<> template <>
struct Vec<uint16_t, 8> { struct Vec<uint16_t, 8> {
using Type = uint4; using Type = uint4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<uint16_t> { struct FloatVec<uint16_t> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<uint32_t> { struct FloatVec<uint32_t> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<uint2> { struct FloatVec<uint2> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<uint4> { struct FloatVec<uint4> {
using Type = Float8_; using Type = Float8_;
}; };
...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { ...@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b; return b;
#else #else
union { union {
uint32_t u32; uint32_t u32;
uint16_t u16[2]; uint16_t u16[2];
} tmp; } tmp;
tmp.u16[0] = a; tmp.u16[0] = a;
tmp.u16[1] = a; tmp.u16[1] = a;
...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { ...@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif #endif
#else #else
tmp.u16[0] = float_to_half(f.x); tmp.u16[0] = float_to_half(f.x);
...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { ...@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) { inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c; uint16_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { ...@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) { inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c; uint32_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { ...@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) { inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ uint2 mul(uint2 a, uint2 b) { inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c; uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { ...@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) { inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint2 c; uint2 c;
...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { ...@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint4 a, uint4 b) { inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c; uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { ...@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) { inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint4 c; uint4 c;
...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { ...@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(uint16_t a, uint16_t b) { inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a); float fa = half_to_float(a);
float fb = half_to_float(b); float fb = half_to_float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) { inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a); float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b); float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) { inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) { inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { ...@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) { inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float4_ fc; Float4_ fc;
...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { ...@@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) { inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { ...@@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float8_ fc; Float8_ fc;
...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { ...@@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { ...@@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(uint16_t v) { inline __device__ float sum(uint16_t v) {
return half_to_float(v); return half_to_float(v);
} }
template<> template <>
inline __device__ float sum(uint32_t v) { inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v); float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y; return tmp.x + tmp.y;
} }
template<> template <>
inline __device__ float sum(uint2 v) { inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
return sum(c); return sum(c);
} }
template<> template <>
inline __device__ float sum(uint4 v) { inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
c = add(c, v.z); c = add(c, v.z);
...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { ...@@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { ...@@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
...@@ -38,37 +40,35 @@ struct Float8_ { ...@@ -38,37 +40,35 @@ struct Float8_ {
}; };
// FP32 vector types for Q, K, V. // FP32 vector types for Q, K, V.
template<> template <>
struct Vec<float, 1> { struct Vec<float, 1> {
using Type = float; using Type = float;
}; };
template<> template <>
struct Vec<float, 2> { struct Vec<float, 2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct Vec<float, 4> { struct Vec<float, 4> {
using Type = float4; using Type = float4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<float> { struct FloatVec<float> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<float2> { struct FloatVec<float2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<float4> { struct FloatVec<float4> {
using Type = float4; using Type = float4;
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { ...@@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ float mul<float, float>(float a, float b) { inline __device__ float mul<float, float>(float a, float b) {
return a * b; return a * b;
} }
template<> template <>
inline __device__ float2 mul(float2 a, float2 b) { inline __device__ float2 mul(float2 a, float2 b) {
float2 c; float2 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { ...@@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float2 mul(float a, float2 b) { inline __device__ float2 mul(float a, float2 b) {
float2 c; float2 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { ...@@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float4 a, float4 b) { inline __device__ float4 mul(float4 a, float4 b) {
float4 c; float4 c;
c.x = a.x * b.x; c.x = a.x * b.x;
...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { ...@@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float a, float4 b) { inline __device__ float4 mul(float a, float4 b) {
float4 c; float4 c;
c.x = a * b.x; c.x = a * b.x;
...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { ...@@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { ...@@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(float v) { inline __device__ float sum(float v) {
return v; return v;
} }
template<> template <>
inline __device__ float sum(float2 v) { inline __device__ float sum(float2 v) {
return v.x + v.y; return v.x + v.y;
} }
template<> template <>
inline __device__ float sum(float4 v) { inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w; return v.x + v.y + v.z + v.w;
} }
template<> template <>
inline __device__ float sum(Float4_ v) { inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y; return v.x.x + v.x.y + v.y.x + v.y.y;
} }
template<> template <>
inline __device__ float sum(Float8_ v) { inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { ...@@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm
...@@ -3,33 +3,39 @@ ...@@ -3,33 +3,39 @@
#include "attention_generic.cuh" #include "attention_generic.cuh"
#include <stdint.h> #include <stdint.h>
#ifdef ENABLE_FP8_E5M2 #ifdef ENABLE_FP8
#include <cuda_fp8.h> #ifndef USE_ROCM
#endif #include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm { namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template<> enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> { struct Vec<uint8_t, 1> {
using Type = uint8_t; using Type = uint8_t;
}; };
template<> template <>
struct Vec<uint8_t, 2> { struct Vec<uint8_t, 2> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint8_t, 4> { struct Vec<uint8_t, 4> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint8_t, 8> { struct Vec<uint8_t, 8> {
using Type = uint2; using Type = uint2;
}; };
#endif // ENABLE_FP8_E5M2
} // namespace vllm } // namespace vllm
...@@ -5,34 +5,24 @@ ...@@ -5,34 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping);
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, const float kv_scale);
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash( void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache,
torch::Tensor& value, torch::Tensor& value_cache,
torch::Tensor& key_cache, torch::Tensor& slot_mapping,
torch::Tensor& value_cache, const std::string& kv_cache_dtype);
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& src_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& dst_cache);
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" #ifdef USE_ROCM
#elif defined(ENABLE_FP8_E4M3) #include "quantization/fp8/amd/quant_utils.cuh"
#include "quantization/fp8/amd_detail/quant_utils.cuh" #else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
...@@ -17,20 +18,17 @@ ...@@ -17,20 +18,17 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping) {
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost; memcpy_type = cudaMemcpyDeviceToHost;
...@@ -40,41 +38,44 @@ void swap_blocks( ...@@ -40,41 +38,44 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination"); TORCH_CHECK(false, "Invalid device combination");
} }
char *src_ptr = static_cast<char*>(src.data_ptr()); // NOTE(youkaichao): keep in mind that `block_mapping` should be
char *dst_ptr = static_cast<char*>(dst.data_ptr()); // a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) { const int64_t num_blocks = block_mapping.size(0);
int64_t src_block_number = pair.first; for (size_t i = 0; i < num_blocks; i++) {
int64_t dst_block_number = pair.second; int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
namespace vllm { namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template<typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
int64_t* value_cache_ptrs, const int64_t* __restrict__ block_mapping,
const int64_t* __restrict__ block_mapping, const int numel_per_block) {
const int numel_per_block) {
const int layer_idx = blockIdx.x; const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
...@@ -92,12 +93,11 @@ __global__ void copy_blocks_kernel( ...@@ -92,12 +93,11 @@ __global__ void copy_blocks_kernel(
} }
} }
} // namespace vllm } // namespace vllm
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping) {
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
...@@ -111,29 +111,23 @@ void copy_blocks( ...@@ -111,29 +111,23 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
} value_cache_ptrs[layer_idx] =
// Create block mapping array. reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
} }
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2; // block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::Tensor block_mapping_tensor = torch::from_blob( torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); .to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
...@@ -142,31 +136,28 @@ void copy_blocks( ...@@ -142,31 +136,28 @@ void copy_blocks(
const at::cuda::OptionalCUDAGuard device_guard(cache_device); const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block); }));
}));
} }
namespace vllm { namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
const int64_t* __restrict__ slot_mapping, // [num_tokens] cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
const int key_stride, // block_size]
const int value_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int head_size, const int block_size, const int x,
const int block_size, const float kv_scale) {
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
...@@ -187,47 +178,39 @@ __global__ void reshape_and_cache_kernel( ...@@ -187,47 +178,39 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
template<typename scalar_t> template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] // head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
const int block_stride, // head_size]
const int key_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int value_stride, const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int num_heads, const int head_size, const int block_size) {
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
...@@ -242,40 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -242,40 +225,37 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t src_value_idx = token_idx * value_stride + i; const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride const int64_t tgt_value_idx = block_idx * block_stride +
+ block_offset * num_heads * head_size block_offset * num_heads * head_size +
+ head_idx * head_size head_idx * head_size + head_offset;
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx]; k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx]; v_cache[tgt_value_idx] = value[src_value_idx];
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ // KV_T is the stored data type of kv-cache.
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \ // CACHE_T is the data type of key and value tensors.
reinterpret_cast<KV_T*>(key.data_ptr()), \ // KV_DTYPE is the real data type of kv-cache.
reinterpret_cast<KV_T*>(value.data_ptr()), \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ <<<grid, block, 0, stream>>>( \
slot_mapping.data_ptr<int64_t>(), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
key_stride, \ reinterpret_cast<KV_T*>(value.data_ptr()), \
value_stride, \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
num_heads, \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
head_size, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
block_size, \ num_heads, head_size, block_size, x, kv_scale);
x, \
kv_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor&
const std::string& kv_cache_dtype, value_cache, // [num_blocks, num_heads, head_size, block_size]
const float kv_scale) torch::Tensor& slot_mapping, // [num_tokens]
{ const std::string& kv_cache_dtype, const float kv_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
...@@ -289,35 +269,18 @@ void reshape_and_cache( ...@@ -289,35 +269,18 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) { DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE(float, float, false); CALL_RESHAPE_AND_CACHE)
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
} }
void reshape_and_cache_flash( void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) const std::string& kv_cache_dtype) {
{
// FIXME: only support auto datatype, does not support fp8 // FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") { if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
...@@ -337,63 +300,47 @@ void reshape_and_cache_flash( ...@@ -337,63 +300,47 @@ void reshape_and_cache_flash(
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), key.scalar_type(), "reshape_and_cache_flash", [&] {
"reshape_and_cache_flash", vllm::reshape_and_cache_flash_kernel<scalar_t>
[&] { <<<grid, block, 0, stream>>>(
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
k_cache.data_ptr<scalar_t>(), value_stride, num_heads, head_size, block_size);
v_cache.data_ptr<scalar_t>(), });
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
} }
namespace vllm { namespace vllm {
template<typename Tout, typename Tin> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache,
Tout* __restrict__ dst_cache, const float kv_scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2) dst_cache[idx] =
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]); fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
block_stride);
void convert_fp8( // Only for testing.
torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& dst_cache) const float kv_scale, const std::string& kv_cache_dtype) {
{
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device); at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0); int64_t num_blocks = src_cache.size(0);
...@@ -403,17 +350,37 @@ void convert_fp8( ...@@ -403,17 +350,37 @@ void convert_fp8(
dim3 block(std::min(block_stride, int64_t(512))); dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) { if (kv_cache_dtype == "auto") {
CALL_CONVERT_FP8(uint8_t, float); if (src_cache.dtype() == at::ScalarType::Float) {
} else if (src_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8(uint8_t, uint16_t); } else if (src_cache.dtype() == at::ScalarType::Half) {
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
} else if (dst_cache.dtype() == at::ScalarType::Float) { CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8(float, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Float) {
} else if (dst_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8(uint16_t, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Half) {
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
} }
#include "cpu_types.hpp" #include "cpu_types.hpp"
namespace { namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated> bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t *__restrict__ output) { scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, ...@@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
} }
} }
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp()); return x / (ones + (zeros - x).exp());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { ...@@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { ...@@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er()); return x * w2 * (ones + (x * w1).er());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { ...@@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh()); return x * w2 * (ones + inner.tanh());
} }
}; // namespace }; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) activation_kernel<scalar_t, silu_act, true>(
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
} }
void gelu_and_mul(torch::Tensor &out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) activation_kernel<scalar_t, gelu_act, true>(
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
} }
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] ...@@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
}); });
} }
void gelu_new(torch::Tensor &out, torch::Tensor &input) { void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { ...@@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
}); });
} }
void gelu_fast(torch::Tensor &out, torch::Tensor &input) { void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
...@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType { ...@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
...@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> { ...@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
...@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> { ...@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
...@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> { ...@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
#endif #endif
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) { const int capacity) {
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, ...@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int seq_len) { const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1); data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, ...@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
} }
template <typename T> template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
const int size) { const int size) {
T max = max_data[0]; T max = max_data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
...@@ -132,9 +137,9 @@ struct reduceQKBlockKernel { ...@@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q, FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t *__restrict__ k_block, const scalar_t* __restrict__ k_block,
float *__restrict__ logits, float scale, float* __restrict__ logits, float scale,
const int token_num) { const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
...@@ -196,8 +201,8 @@ struct reduceQKBlockKernel { ...@@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t> int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t &&acc) { acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM); static_assert(BLOCK_SIZE == ELEM_NUM);
...@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, ...@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
}); });
} }
}; // namespace }; // namespace
// Paged attention v1 // Paged attention v1
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) { const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
...@@ -243,32 +248,31 @@ struct paged_attention_v1_impl { ...@@ -243,32 +248,31 @@ struct paged_attention_v1_impl {
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
seq_len - (block_num - 1) * BLOCK_SIZE; float* __restrict__ thread_block_logits =
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded; logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
...@@ -282,8 +286,7 @@ struct paged_attention_v1_impl { ...@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, seq_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
...@@ -293,14 +296,14 @@ struct paged_attention_v1_impl { ...@@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition; head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -311,7 +314,7 @@ struct paged_attention_v1_impl { ...@@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -340,16 +343,16 @@ struct paged_attention_v1_impl { ...@@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
template <typename T, int BLOCK_SIZE> template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -359,68 +362,73 @@ void paged_attention_v1_impl_launcher( ...@@ -359,68 +362,73 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 192:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break; break;
default: case 256:
TORCH_CHECK(false, "Unsupported head size: ", head_size); LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
} }
} }
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \ CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, void paged_attention_v1(
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor &seq_lens, int block_size, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_seq_len, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const std::string &kv_cache_dtype, float kv_scale) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
...@@ -434,23 +442,24 @@ namespace { ...@@ -434,23 +442,24 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] // max_num_partitions, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x]
// head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) { const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
...@@ -468,8 +477,7 @@ struct paged_attention_v2_impl { ...@@ -468,8 +477,7 @@ struct paged_attention_v2_impl {
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
...@@ -477,15 +485,14 @@ struct paged_attention_v2_impl { ...@@ -477,15 +485,14 @@ struct paged_attention_v2_impl {
const int token_num = const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
...@@ -493,10 +500,10 @@ struct paged_attention_v2_impl { ...@@ -493,10 +500,10 @@ struct paged_attention_v2_impl {
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
...@@ -510,13 +517,13 @@ struct paged_attention_v2_impl { ...@@ -510,13 +517,13 @@ struct paged_attention_v2_impl {
logits, token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto &&[max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr; scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) { if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions + auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
...@@ -538,13 +545,13 @@ struct paged_attention_v2_impl { ...@@ -538,13 +545,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition; output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -555,7 +562,7 @@ struct paged_attention_v2_impl { ...@@ -555,7 +562,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
...@@ -587,8 +594,7 @@ struct paged_attention_v2_impl { ...@@ -587,8 +594,7 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
...@@ -603,11 +609,11 @@ struct paged_attention_v2_impl { ...@@ -603,11 +609,11 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1) #pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
...@@ -616,17 +622,16 @@ struct paged_attention_v2_impl { ...@@ -616,17 +622,16 @@ struct paged_attention_v2_impl {
const int partition_num = const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float *__restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions; head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out = const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output = scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
...@@ -645,21 +650,21 @@ struct paged_attention_v2_impl { ...@@ -645,21 +650,21 @@ struct paged_attention_v2_impl {
} }
}; };
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512> template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher( void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -670,73 +675,78 @@ void paged_attention_v2_impl_launcher( ...@@ -670,73 +675,78 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1); int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 192:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break; break;
default: case 256:
TORCH_CHECK(false, "Unsupported head size: ", head_size); LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
} }
} }
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_seq_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \ CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, void paged_attention_v2(
torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor& value_cache, int num_kv_heads, float scale,
float scale, torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor &seq_lens, int block_size, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_seq_len, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const std::string &kv_cache_dtype, float kv_scale) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
......
...@@ -5,25 +5,26 @@ ...@@ -5,25 +5,26 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor> &value_caches, const torch::Tensor& mapping_pairs,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs, const int element_num_per_block,
const int element_num_per_block, const int layer_num) { const int layer_num) {
const size_t pair_num = mapping_pairs.size(); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second; element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset; scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset; source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset; target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
...@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( ...@@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(
template <typename scalar_t> template <typename scalar_t>
void reshape_and_cache_cpu_impl( void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens, const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) { const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size; const int block_elem_num = num_heads * head_size * block_size;
...@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( ...@@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx = int src_value_head_idx =
token_idx * value_stride + head_idx * head_size; token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx; const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx; const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size; const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size; const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache + scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache + scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
...@@ -79,39 +80,31 @@ void reshape_and_cache_cpu_impl( ...@@ -79,39 +80,31 @@ void reshape_and_cache_cpu_impl(
} }
} }
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches, void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) { const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
return; return;
} }
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel(); const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs, copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers); element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
}); });
} }
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor &slot_mapping, torch::Tensor& slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
...@@ -135,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, ...@@ -135,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}); });
} }
void swap_blocks(torch::Tensor &src, torch::Tensor &dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const std::map<int64_t, int64_t> &block_mapping) { const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment