Unverified commit 9aea2555, authored by fzyzcjy, committed by GitHub

Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)

parent fcc11e5e
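This change adds an optional fused path to the rope kernel: the same launch that applies rotary embedding to q/k can also scatter the rotated keys and the raw values into the KV-cache buffers at the given cache locations. Below is a minimal sketch of how the new Python surface introduced in this change is expected to be called; the head counts, pool size, base of 8000, and the randomly generated tensors are illustrative only.

import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace

nnz, num_q_heads, num_kv_heads, head_size, pool_size = 8, 32, 8, 64, 1024
device, dtype = "cuda", torch.bfloat16

positions = torch.arange(nnz, device=device)
query = torch.randn(nnz, num_q_heads * head_size, dtype=dtype, device=device)
key = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device)
value = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device)

# cos_sin_cache: (max_seq_len, rotary_dim), first half cos, second half sin, float32
rotary_dim, max_seq_len, base = head_size, 4096, 8000
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
freqs = torch.outer(torch.arange(max_seq_len, dtype=torch.float), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(device)

# Paged KV buffers laid out as (pool_size, num_kv_heads * head_size)
k_buffer = torch.zeros(pool_size, num_kv_heads * head_size, dtype=dtype, device=device)
v_buffer = torch.zeros(pool_size, num_kv_heads * head_size, dtype=dtype, device=device)
cache_loc = torch.randperm(pool_size, device=device)[:nnz].to(torch.int64)

# One kernel launch: rope on q/k plus the KV-cache write of k/v at cache_loc.
apply_rope_with_cos_sin_cache_inplace(
    positions=positions,
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
    fused_set_kv_buffer_arg=FusedSetKVBufferArg(
        value=value,
        k_buffer=k_buffer,
        v_buffer=v_buffer,
        k_scale=None,
        v_scale=None,
        cache_loc=cache_loc,
    ),
)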
import os
import sys
from contextlib import nullcontext
import torch
# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
def __enter__(self):
self.outnull_file = open(os.devnull, "w")
self.errnull_file = open(os.devnull, "w")
self.old_stdout_fileno_undup = sys.stdout.fileno()
self.old_stderr_fileno_undup = sys.stderr.fileno()
self.old_stdout_fileno = os.dup(sys.stdout.fileno())
self.old_stderr_fileno = os.dup(sys.stderr.fileno())
self.old_stdout = sys.stdout
self.old_stderr = sys.stderr
os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
sys.stdout = self.outnull_file
sys.stderr = self.errnull_file
return self
def __exit__(self, *_):
sys.stdout = self.old_stdout
sys.stderr = self.old_stderr
os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
os.close(self.old_stdout_fileno)
os.close(self.old_stderr_fileno)
self.outnull_file.close()
self.errnull_file.close()
# NOTE copied and modified from DeepGEMM
def bench_kineto(
fn,
kernel_names,
num_tests: int = 30,
suppress_kineto_output: bool = False,
trace_path: str = None,
flush_l2: bool = True,
with_multiple_kernels: bool = False,
):
    # Kineto profiling conflicts with Nsight Systems
using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
flush_l2_size = int(8e9 // 4)
    # Warmup run (some auto-tuning kernels print on their first invocation)
fn()
# Profile
suppress = (
suppress_stdout_stderr
if suppress_kineto_output and not using_nsys
else nullcontext
)
with suppress():
schedule = (
torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
if not using_nsys
else None
)
profiler = (
torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
)
if not using_nsys
else nullcontext()
)
with profiler:
for i in range(2):
for _ in range(num_tests):
if flush_l2:
torch.empty(
flush_l2_size, dtype=torch.int, device="cuda"
).zero_()
fn()
if not using_nsys:
profiler.step()
# Return 1 if using Nsight Systems
if using_nsys:
return 1
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tuple = isinstance(kernel_names, tuple)
prof_lines = (
profiler.key_averages()
.table(sort_by="cuda_time_total", max_name_column_width=100)
.split("\n")
)
kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
if not with_multiple_kernels:
for name in kernel_names:
assert (
sum([name in line for line in prof_lines]) == 1
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
# Save chrome traces
if trace_path is not None:
profiler.export_chrome_trace(trace_path)
# Return average kernel times
units = {"ms": 1e3, "us": 1e6}
kernel_times = []
for name in kernel_names:
total_time = 0
total_num = 0
for line in prof_lines:
if name in line:
time_str = line.split()[-2]
num_str = line.split()[-1]
for unit, scale in units.items():
if unit in time_str:
total_time += (
float(time_str.replace(unit, "")) / scale * int(num_str)
)
total_num += int(num_str)
break
kernel_times.append(total_time / total_num)
return tuple(kernel_times) if is_tuple else kernel_times[0]
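A small usage sketch of `bench_kineto`, assuming a CUDA device. The kernel-name substring "elementwise" relies on PyTorch's internal CUDA kernel naming and is only an assumption here; the benchmark script below matches "BatchQKApplyRotaryPosIds" in the same substring fashion.

import torch

a = torch.randn(1 << 20, device="cuda")
b = torch.randn(1 << 20, device="cuda")

# Average CUDA time (in seconds) of the kernel whose profiler name contains the substring.
elapsed_s = bench_kineto(
    lambda: torch.add(a, b),
    kernel_names="elementwise",
    num_tests=10,
    flush_l2=False,
)
print(f"{elapsed_s * 1e6:.2f} us")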
import itertools
import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
create_inputs,
)
from sglang.srt.bench_utils import bench_kineto
configs = [
(batch_size, seq_len, save_kv_cache)
for batch_size, seq_len in (
(1, 1),
(32, 1),
(128, 1),
(512, 1),
(2, 512),
(4, 4096),
)
for save_kv_cache in (False, True)
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "save_kv_cache"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang"],
line_names=["SGL Kernel"],
styles=[("green", "-")],
ylabel="us",
plot_name="bench_rotary_embedding",
args={},
)
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
device = torch.device("cuda")
num_q_heads = 32
num_kv_heads = 8
head_size = 64
dtype = torch.bfloat16
config = dict(
head_size=head_size,
rotary_dim=64,
max_position_embeddings=4096,
base=8000,
is_neox_style=True,
dtype=dtype,
)
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
inputs = create_inputs(
head_size=head_size,
batch_size=batch_size,
seq_len=seq_len,
device=device,
dtype=dtype,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
)
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
bench_fn = lambda: rope_flashinfer.forward_cuda(
inputs["pos_ids"],
query_flashinfer,
key_flashinfer,
fused_set_kv_buffer_arg=(
FusedSetKVBufferArg(
value=inputs["value"],
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
k_scale=None,
v_scale=None,
cache_loc=inputs["out_cache_loc"],
)
if save_kv_cache
else None
),
)
time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
return time_s * 1e6
if __name__ == "__main__":
benchmark.run(print_data=True)
...@@ -89,7 +89,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
   /*
......
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SGL_POS_ENC_CUH_
#define SGL_POS_ENC_CUH_
#include <flashinfer/pos_enc.cuh> // upstream
namespace flashinfer {
namespace kv_buffer_saver {
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
vec_t<float, vec_size>& v_vec,
IdType& kv_cache_offset,
DType* v,
IdType* kv_cache_loc,
uint32_t idx,
uint32_t tx,
uint32_t kv_head_idx,
size_t v_stride_n,
size_t v_stride_h) {
kv_cache_offset = kv_cache_loc[idx];
DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
v_vec.cast_load(v_ptr + tx * vec_size);
}
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
IdType& kv_cache_offset,
vec_t<float, vec_size>& k_vec,
vec_t<float, vec_size>& v_vec,
DType* k_buffer,
DType* v_buffer,
uint32_t idx,
uint32_t tx,
uint32_t kv_head_idx,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h) {
DType* k_buffer_ptr =
k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
DType* v_buffer_ptr =
v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
k_vec.cast_store(k_buffer_ptr + tx * vec_size);
v_vec.cast_store(v_buffer_ptr + tx * vec_size);
}
} // namespace kv_buffer_saver
template <
bool save_kv_cache,
bool interleave,
uint32_t head_dim,
uint32_t vec_size,
uint32_t bdx,
typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
DType* q,
DType* k,
DType* v,
DType* q_rope,
DType* k_rope,
DType* k_buffer,
DType* v_buffer,
float* __restrict__ cos_sin_cache,
IdType* __restrict__ pos_ids,
uint32_t nnz,
uint32_t num_qo_heads,
uint32_t num_kv_heads,
uint32_t rotary_dim,
size_t q_stride_n,
size_t q_stride_h,
size_t k_stride_n,
size_t k_stride_h,
size_t v_stride_n,
size_t v_stride_h,
size_t q_rope_stride_n,
size_t q_rope_stride_h,
size_t k_rope_stride_n,
size_t k_rope_stride_h,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h,
IdType* __restrict__ kv_cache_loc) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
uint32_t by = blockIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const int half_rotary_dim = rotary_dim / 2;
// 1. if interleave:
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
// 2. if not interleave
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
if (tx * vec_size < rotary_dim) {
int sin_offset = rotary_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
}
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
}
if (by < num_qo_heads) {
uint32_t qo_head_idx = by;
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
vec_t<float, vec_size> q_vec;
if constexpr (interleave) {
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
} else {
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
} else {
uint32_t kv_head_idx = by - num_qo_heads;
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
vec_t<float, vec_size> v_vec;
IdType kv_cache_offset;
if constexpr (save_kv_cache) {
kv_buffer_saver::prepare<DType, IdType, vec_size>(
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
}
vec_t<float, vec_size> k_vec;
if constexpr (interleave) {
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
} else {
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
if constexpr (save_kv_cache) {
kv_buffer_saver::save<DType, IdType, vec_size>(
kv_cache_offset,
k_vec,
v_vec,
k_buffer,
v_buffer,
idx,
tx,
kv_head_idx,
k_buffer_stride_n,
k_buffer_stride_h,
v_buffer_stride_n,
v_buffer_stride_h);
}
}
}
}
template <
bool save_kv_cache,
bool interleave,
uint32_t head_dim,
uint32_t vec_size,
uint32_t bdx,
typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
DType* q,
DType* k,
DType* v,
DType* q_rope,
DType* k_rope,
DType* k_buffer,
DType* v_buffer,
float* __restrict__ cos_sin_cache,
IdType* __restrict__ pos_ids,
uint32_t nnz,
uint32_t num_qo_heads,
uint32_t num_kv_heads,
uint32_t rotary_dim,
size_t q_stride_n,
size_t q_stride_h,
size_t k_stride_n,
size_t k_stride_h,
size_t v_stride_n,
size_t v_stride_h,
size_t q_rope_stride_n,
size_t q_rope_stride_h,
size_t k_rope_stride_n,
size_t k_rope_stride_h,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h,
IdType* __restrict__ kv_cache_loc) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> cos, sin;
if (bx * bdy + ty < nnz) {
const uint32_t idx = bx * bdy + ty;
const IdType pos = pos_ids[idx];
const int half_rotary_dim = rotary_dim / 2;
// 1. if interleave:
// - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
// - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
// 2. if not interleave
// - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
// - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
if (tx * vec_size < rotary_dim) {
int sin_offset = rotary_dim / 2;
int vec_idx;
if constexpr (interleave) {
vec_idx = (tx * vec_size) / 2; // Force integer division
} else {
vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim
}
cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
}
    // Do not unroll this loop: the number of heads can be large, and unrolling may hurt performance
#pragma unroll 1
for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
vec_t<float, vec_size> q_vec;
if constexpr (interleave) {
q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
} else {
q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
}
#pragma unroll 1
for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
vec_t<float, vec_size> v_vec;
IdType kv_cache_offset;
if constexpr (save_kv_cache) {
kv_buffer_saver::prepare<DType, IdType, vec_size>(
v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
}
vec_t<float, vec_size> k_vec;
if constexpr (interleave) {
k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
} else {
k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
if constexpr (save_kv_cache) {
kv_buffer_saver::save<DType, IdType, vec_size>(
kv_cache_offset,
k_vec,
v_vec,
k_buffer,
v_buffer,
idx,
tx,
kv_head_idx,
k_buffer_stride_n,
k_buffer_stride_h,
v_buffer_stride_n,
v_buffer_stride_h);
}
}
}
}
#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
if (save_kv_cache) { \
const bool SAVE_KV_CACHE = true; \
__VA_ARGS__ \
} else { \
const bool SAVE_KV_CACHE = false; \
__VA_ARGS__ \
}
template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
DType* q,
DType* k,
DType* v,
DType* q_rope,
DType* k_rope,
DType* k_buffer,
DType* v_buffer,
float* cos_sin_cache,
IdType* pos_ids,
uint32_t nnz,
uint32_t num_qo_heads,
uint32_t num_kv_heads,
uint32_t rotary_dim,
uint32_t head_dim,
size_t q_stride_n,
size_t q_stride_h,
size_t k_stride_n,
size_t k_stride_h,
size_t v_stride_n,
size_t v_stride_h,
size_t q_rope_stride_n,
size_t q_rope_stride_h,
size_t k_rope_stride_n,
size_t k_rope_stride_h,
size_t k_buffer_stride_n,
size_t k_buffer_stride_h,
size_t v_buffer_stride_n,
size_t v_buffer_stride_h,
IdType* kv_cache_loc,
bool interleave,
bool save_kv_cache,
cudaStream_t stream = nullptr) {
int dev_id = 0;
int num_sms = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
// operate on 16 Bytes at a time
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
// how many threads needed per head_dim
constexpr uint32_t bdx = HEAD_DIM / vec_size;
// how many threads needed per block
uint32_t num_threads = std::max(128U, bdx);
// how many tokens can we process in a block
uint32_t bdy = num_threads / bdx;
// how many blocks needed to process all tokens
uint32_t nblks_x = (nnz + bdy - 1) / bdy;
void* args[] = {
(void*)&q,
(void*)&k,
(void*)&v,
(void*)&q_rope,
(void*)&k_rope,
(void*)&k_buffer,
(void*)&v_buffer,
(void*)&cos_sin_cache,
(void*)&pos_ids,
(void*)&nnz,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&rotary_dim,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&v_stride_n,
(void*)&v_stride_h,
(void*)&q_rope_stride_n,
(void*)&q_rope_stride_h,
(void*)&k_rope_stride_n,
(void*)&k_rope_stride_h,
(void*)&k_buffer_stride_n,
(void*)&k_buffer_stride_h,
(void*)&v_buffer_stride_n,
(void*)&v_buffer_stride_h,
(void*)&kv_cache_loc};
auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
SAVE_KV_CACHE,
INTERLEAVE,
HEAD_DIM,
vec_size,
bdx,
DType,
IdType>;
int num_blocks_per_sm_0 = 0;
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
dim3 nblks(nblks_x);
dim3 nthrs(bdx, bdy);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
} else {
dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
dim3 nthrs(bdx, bdy);
auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
SAVE_KV_CACHE,
INTERLEAVE,
HEAD_DIM,
vec_size,
bdx,
DType,
IdType>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
}
});
});
});
return cudaSuccess;
}
} // namespace flashinfer
#endif // SGL_POS_ENC_CUH_
...@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <flashinfer/pos_enc.cuh>
+#include "pos_enc.cuh"
 #include "pytorch_extension_utils.h"

 using namespace flashinfer;
...@@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
-    int64_t cuda_stream) {
+    int64_t cuda_stream,
+    const std::optional<at::Tensor>& v,
+    const std::optional<at::Tensor>& k_buffer,
+    const std::optional<at::Tensor>& v_buffer,
+    const std::optional<at::Tensor>& kv_cache_loc) {
   CHECK_LAST_DIM_CONTIGUOUS(q);
   CHECK_LAST_DIM_CONTIGUOUS(k);
+  const bool save_kv_cache = v.has_value();
+  if (save_kv_cache) {
+    TORCH_CHECK(v.has_value());
+    TORCH_CHECK(k_buffer.has_value());
+    TORCH_CHECK(v_buffer.has_value());
+    TORCH_CHECK(kv_cache_loc.has_value());
+    CHECK_LAST_DIM_CONTIGUOUS(v.value());
+    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
+    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
+    CHECK_DIM(3, k_buffer.value());      // k_buffer: (nnz, H_K, D)
+    CHECK_DIM(3, v_buffer.value());      // v_buffer: (nnz, H_V, D)
+    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
+    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (n)
+    CHECK_INPUT(kv_cache_loc.value());
+  }
+  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
+  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
+  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
+  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
+  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
+  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
+  auto kv_cache_loc_ptr = save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;
   CHECK_INPUT(cos_sin_cache);
   CHECK_INPUT(pos_ids);
   auto device = q.device();
...@@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   CHECK_EQ(pos_ids.device(), device);
   CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
   CHECK_DIM(3, k);  // k: (nnz, H_K, D)
   // cos_sin_cache: (max_seq_len, R)
   // First half of R is cos, second half is sin
   CHECK_DIM(2, cos_sin_cache);
...@@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t q_stride_h = q.stride(1);
   size_t k_stride_n = k.stride(0);
   size_t k_stride_h = k.stride(1);
   size_t q_rope_stride_n = q_rope.stride(0);
   size_t q_rope_stride_h = q_rope.stride(1);
   size_t k_rope_stride_n = k_rope.stride(0);
...@@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache(
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
-    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
-        static_cast<c_type*>(q.data_ptr()),
-        static_cast<c_type*>(k.data_ptr()),
-        static_cast<c_type*>(q_rope.data_ptr()),
-        static_cast<c_type*>(k_rope.data_ptr()),
-        static_cast<float*>(cos_sin_cache.data_ptr()),
-        static_cast<int64_t*>(pos_ids.data_ptr()),
-        nnz,
-        num_qo_heads,
-        num_kv_heads,
-        rotary_dim,
-        head_dim,
-        q_stride_n,
-        q_stride_h,
-        k_stride_n,
-        k_stride_h,
-        q_rope_stride_n,
-        q_rope_stride_h,
-        k_rope_stride_n,
-        k_rope_stride_h,
-        interleave,
-        stream);
-    TORCH_CHECK(
-        status == cudaSuccess,
-        "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
+    // TODO: temporarily use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` only when save_kv_cache is set,
+    // to avoid changing the original code path; the enhanced branch is feature-complete and should become
+    // the only path later.
+    if (save_kv_cache) {
+      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
+          static_cast<c_type*>(q.data_ptr()),
+          static_cast<c_type*>(k.data_ptr()),
+          save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
+          static_cast<c_type*>(q_rope.data_ptr()),
+          static_cast<c_type*>(k_rope.data_ptr()),
+          save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
+          save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
+          static_cast<float*>(cos_sin_cache.data_ptr()),
+          static_cast<int64_t*>(pos_ids.data_ptr()),
+          nnz,
+          num_qo_heads,
+          num_kv_heads,
+          rotary_dim,
+          head_dim,
+          q_stride_n,
+          q_stride_h,
+          k_stride_n,
+          k_stride_h,
+          v_stride_n,
+          v_stride_h,
+          q_rope_stride_n,
+          q_rope_stride_h,
+          k_rope_stride_n,
+          k_rope_stride_h,
+          k_buffer_stride_n,
+          k_buffer_stride_h,
+          v_buffer_stride_n,
+          v_buffer_stride_h,
+          kv_cache_loc_ptr,
+          interleave,
+          save_kv_cache,
+          stream);
+      TORCH_CHECK(
+          status == cudaSuccess,
+          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
+              std::string(cudaGetErrorString(status)));
+    } else {
+      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
+          static_cast<c_type*>(q.data_ptr()),
+          static_cast<c_type*>(k.data_ptr()),
+          static_cast<c_type*>(q_rope.data_ptr()),
+          static_cast<c_type*>(k_rope.data_ptr()),
+          static_cast<float*>(cos_sin_cache.data_ptr()),
+          static_cast<int64_t*>(pos_ids.data_ptr()),
+          nnz,
+          num_qo_heads,
+          num_kv_heads,
+          rotary_dim,
+          head_dim,
+          q_stride_n,
+          q_stride_h,
+          k_stride_n,
+          k_stride_h,
+          q_rope_stride_n,
+          q_rope_stride_h,
+          k_rope_stride_n,
+          k_rope_stride_h,
+          interleave,
+          stream);
+      TORCH_CHECK(
+          status == cudaSuccess,
+          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + std::string(cudaGetErrorString(status)));
+    }
     return true;
   });
 }
...@@ -150,7 +150,11 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
-    int64_t cuda_stream);
+    int64_t cuda_stream,
+    const std::optional<at::Tensor>& v,
+    const std::optional<at::Tensor>& k_buffer,
+    const std::optional<at::Tensor>& v_buffer,
+    const std::optional<at::Tensor>& kv_cache_loc);

 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
......
...@@ -21,6 +21,7 @@ from sgl_kernel.attention import (
 )
 from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
 from sgl_kernel.elementwise import (
+    FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
     gelu_and_mul,
......
-from typing import Optional
+from dataclasses import dataclass
+from typing import Any, Optional

 import torch
 from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
...@@ -237,6 +238,31 @@ if torch.version.hip is not None:
         return out


+@dataclass
+class FusedSetKVBufferArg:
+    """
+    value : Optional[torch.Tensor]
+        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
+    k_buffer : Optional[torch.Tensor]
+        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
+    v_buffer : Optional[torch.Tensor]
+        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
+    k_scale : Optional[float]
+        Scale factor for keys.
+    v_scale : Optional[float]
+        Scale factor for values.
+    cache_loc : Optional[torch.Tensor]
+        Cache location tensor, used for indexing the KV cache.
+    """
+
+    value: torch.Tensor
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
+    k_scale: Optional[float]
+    v_scale: Optional[float]
+    cache_loc: torch.Tensor
+
+
 def apply_rope_with_cos_sin_cache_inplace(
     positions: torch.Tensor,
     query: torch.Tensor,
...@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     head_size: int,
     cos_sin_cache: torch.Tensor,
     is_neox: bool = True,
+    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
 ) -> None:
     r"""
     Apply rotary embedding to keys and queries with precomputed cos/sin values.
...@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
         * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
           we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+    fused_set_kv_buffer_arg : Optional[FusedSetKVBufferArg]
+        If provided, the set-KV-buffer (KV-cache write) operation is fused into this kernel.

     Note
     ----
     The rotary dimension is determined by the cosine cache and sine cache.
...@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")

+    if (a := fused_set_kv_buffer_arg) is not None:
+        assert a.k_scale is None, "k_scale is not yet supported"
+        assert a.v_scale is None, "v_scale is not yet supported"
+        assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
+
+    def _view_3d(x):
+        return x.view(x.shape[0], -1, head_size)
+
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        query.view(query.shape[0], -1, head_size),
-        key.view(key.shape[0], -1, head_size),
-        query.view(query.shape[0], -1, head_size),
-        key.view(key.shape[0], -1, head_size),
+        _view_3d(query),
+        _view_3d(key),
+        _view_3d(query),
+        _view_3d(key),
         cos_sin_cache,
         positions.long(),
         (not is_neox),
         get_cuda_stream(),
+        (
+            _view_3d(fused_set_kv_buffer_arg.value)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            _view_3d(fused_set_kv_buffer_arg.k_buffer)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            _view_3d(fused_set_kv_buffer_arg.v_buffer)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            fused_set_kv_buffer_arg.cache_loc
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
     )
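For reference, a rough sketch of the unfused sequence that `fused_set_kv_buffer_arg` folds into a single launch: rotate q/k in place, then scatter the rotated keys and the untouched values into the KV buffers. The helper name `rope_then_set_kv_unfused` and the 2-D `(pool_size, num_kv_heads * head_size)` buffer layout are assumptions for illustration, matching the shapes given in the `FusedSetKVBufferArg` docstring.

import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

def rope_then_set_kv_unfused(
    positions, query, key, value, head_size, cos_sin_cache, k_buffer, v_buffer, cache_loc, is_neox=True
):
    # Step 1: rotate q/k in place (no fused argument).
    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )
    # Step 2: scatter the rotated keys and the raw values into the KV buffers at cache_loc.
    k_buffer[cache_loc] = key
    v_buffer[cache_loc] = value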
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
# vLLM torch native
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
class RotaryEmbedding(torch.nn.Module):
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
# Modification: float32 is required for the rotary embedding to work correctly
query = query.to(torch.float32)
key = key.to(torch.float32)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
# Modification: convert to the correct dtype
query = query.to(self.dtype)
key = key.to(self.dtype)
return query, key
class FlashInferRotaryEmbedding(RotaryEmbedding):
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
return query, key
class MHATokenToKVPool:
KV_POOL_SIZE = 16384
def __init__(
self,
head_num: int,
head_dim: int,
):
self.head_num = head_num
self.head_dim = head_dim
self.size = MHATokenToKVPool.KV_POOL_SIZE
self.page_size = 1
self.store_dtype = torch.bfloat16
self.device = "cuda"
self.layer_num = 1
self.start_layer = 0
self._create_buffers()
def _create_buffers(self):
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
def set_kv_buffer(
self,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = 0
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def create_inputs(
head_size: int,
batch_size: int,
seq_len: int,
device,
dtype: torch.dtype,
num_q_heads: int,
num_kv_heads: int,
):
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
)
key = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
value = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
out_cache_loc = torch.randperm(
MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
)[: batch_size * seq_len].clone()
return dict(
pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
)
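A compact sketch of how these testing helpers can be combined to cross-check the fused KV write against a plain pool write; the batch size, sequence length, base, and tolerances are illustrative, and the fully parametrized test follows below.

import torch
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    MHATokenToKVPool,
    create_inputs,
)

num_q_heads, num_kv_heads, head_size = 32, 8, 64
pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
pool_fused = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
rope = FlashInferRotaryEmbedding(
    head_size=head_size,
    rotary_dim=head_size,
    max_position_embeddings=4096,
    base=8000,
    is_neox_style=True,
    dtype=torch.bfloat16,
).to("cuda")
inputs = create_inputs(
    head_size=head_size,
    batch_size=2,
    seq_len=16,
    device="cuda",
    dtype=torch.bfloat16,
    num_q_heads=num_q_heads,
    num_kv_heads=num_kv_heads,
)

# Fused path: rope + KV-cache write in one launch.
q, k = inputs["query"].clone(), inputs["key"].clone()
rope.forward_cuda(
    inputs["pos_ids"],
    q,
    k,
    fused_set_kv_buffer_arg=FusedSetKVBufferArg(
        value=inputs["value"],
        k_buffer=pool_fused.k_buffer[0].view(-1, num_kv_heads * head_size),
        v_buffer=pool_fused.v_buffer[0].view(-1, num_kv_heads * head_size),
        k_scale=None,
        v_scale=None,
        cache_loc=inputs["out_cache_loc"],
    ),
)

# Reference path: write the already-rotated keys and the raw values with a plain pool update.
pool_ref.set_kv_buffer(
    loc=inputs["out_cache_loc"],
    cache_k=k.view(-1, num_kv_heads, head_size),
    cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
)
torch.testing.assert_close(pool_ref.k_buffer[0], pool_fused.k_buffer[0], atol=1e-2, rtol=1e-2)
torch.testing.assert_close(pool_ref.v_buffer[0], pool_fused.v_buffer[0], atol=1e-2, rtol=1e-2)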
...@@ -2,153 +2,51 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import pytest
 import torch
-from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel.testing.rotary_embedding import (
+    FlashInferRotaryEmbedding,
+    MHATokenToKVPool,
+    RotaryEmbedding,
+    create_inputs,
+)

[removed: the local _apply_rotary_emb, RotaryEmbedding, and FlashInferRotaryEmbedding definitions, now provided by sgl_kernel.testing.rotary_embedding]

 @pytest.mark.parametrize(
-    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
+    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
     [
-        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
-        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
-        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
-        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
-        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
-        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
+        # GPT-OSS cases
+        *[
+            (
+                64,
+                64,
+                4096,
+                8000,
+                True,
+                torch.bfloat16,
+                "cuda",
+                batch_size,
+                seq_len,
+                64,
+                8,
+                save_kv_cache,
+            )
+            for batch_size, seq_len in (
+                (1, 1),
+                (32, 1),
+                (128, 1),
+                (512, 1),
+                (2, 512),
+                (4, 4096),
+            )
+            for save_kv_cache in (False, True)
+        ],
+        # Other cases
+        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
+        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
+        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
+        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
+        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
+        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
     ],
 )
 def test_correctness(
...@@ -163,34 +61,77 @@ def test_correctness(
     seq_len: int,
     num_q_heads: int,
     num_kv_heads: int,
+    save_kv_cache: bool,
 ):
-    rope_ref = RotaryEmbedding(
-        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
-    ).to(device)
-    rope_flashinfer = FlashInferRotaryEmbedding(
-        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
-    ).to(device)
-
-    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
-    query = torch.randn(
-        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
-    )
-    key = torch.randn(
-        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
-    )
+    config = dict(
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        max_position_embeddings=max_position_embeddings,
+        base=base,
+        is_neox_style=is_neox_style,
+        dtype=dtype,
+    )
+
+    rope_ref = RotaryEmbedding(**config).to(device)
+    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
+
+    inputs = create_inputs(
+        head_size=head_size,
+        batch_size=batch_size,
+        seq_len=seq_len,
+        device=device,
+        dtype=dtype,
+        num_q_heads=num_q_heads,
+        num_kv_heads=num_kv_heads,
+    )
+    if save_kv_cache:
+        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
+        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

-    query_ref, key_ref = query.clone(), key.clone()
-    query_flashinfer, key_flashinfer = query.clone(), key.clone()
+    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
+    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

-    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
+    query_ref_out, key_ref_out = rope_ref.forward_native(
+        inputs["pos_ids"], query_ref, key_ref
+    )
+    if save_kv_cache:
+        pool_ref.set_kv_buffer(
+            loc=inputs["out_cache_loc"],
+            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
+            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
+        )
+
     query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
-        pos_ids, query_flashinfer, key_flashinfer
+        inputs["pos_ids"],
+        query_flashinfer,
+        key_flashinfer,
+        fused_set_kv_buffer_arg=(
+            FusedSetKVBufferArg(
+                value=inputs["value"],
+                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
+                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
+                k_scale=None,
+                v_scale=None,
+                cache_loc=inputs["out_cache_loc"],
+            )
+            if save_kv_cache
+            else None
+        ),
     )

     torch.testing.assert_close(
         query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
     )
     torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)

+    if save_kv_cache:
+        for field in ["k_buffer", "v_buffer"]:
+            x_ref = getattr(pool_ref, field)[0]
+            x_flashinfer = getattr(pool_flashinfer, field)[0]
+            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
+            nonzero_ref = x_ref != 0
+            nonzero_flashinfer = x_flashinfer != 0
+            assert torch.all(nonzero_ref == nonzero_flashinfer)
+

 if __name__ == "__main__":
......