Unverified Commit b5caa22d authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[kernel] port rope cuda kernel to sgl-kernel (#2993)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 73401fd0
...@@ -222,3 +222,6 @@ work_dirs/ ...@@ -222,3 +222,6 @@ work_dirs/
compile_commands.json compile_commands.json
*.iml *.iml
# VSCode
.vscode
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sgl-kernel" name = "sgl-kernel"
version = "0.0.2.post14" version = "0.0.2.post15"
description = "Kernel Library for SGLang" description = "Kernel Library for SGLang"
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"
......
...@@ -53,6 +53,7 @@ ext_modules = [ ...@@ -53,6 +53,7 @@ ext_modules = [
"src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
], ],
include_dirs=include_dirs, include_dirs=include_dirs,
extra_compile_args={ extra_compile_args={
......
...@@ -6,6 +6,7 @@ from sgl_kernel.ops import ( ...@@ -6,6 +6,7 @@ from sgl_kernel.ops import (
int8_scaled_mm, int8_scaled_mm,
moe_align_block_size, moe_align_block_size,
register_graph_buffers, register_graph_buffers,
rotary_embedding,
sampling_scaling_penalties, sampling_scaling_penalties,
) )
...@@ -18,4 +19,5 @@ __all__ = [ ...@@ -18,4 +19,5 @@ __all__ = [
"sampling_scaling_penalties", "sampling_scaling_penalties",
"get_graph_buffer_ipc_meta", "get_graph_buffer_ipc_meta",
"register_graph_buffers", "register_graph_buffers",
"rotary_embedding",
] ]
// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, int rot_offset,
int embed_dim) {
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = __ldg(cos_ptr + x_index);
sin = __ldg(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = __ldg(cos_ptr + x_index / 2);
sin = __ldg(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
}
void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
query_stride, key_stride, num_heads, num_kv_heads, head_size);
} else {
rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
query_stride, key_stride, num_heads, num_kv_heads, head_size);
}
});
}
...@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma ...@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias); const c10::optional<torch::Tensor>& bias);
// rotary embedding
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce // trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
...@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
// int8_scaled_mm // int8_scaled_mm
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
// rotary embedding
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
} }
...@@ -7,6 +7,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar ...@@ -7,6 +7,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import ( from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties, sampling_scaling_penalties as _sampling_scaling_penalties,
) )
...@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ...@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
out_dtype, out_dtype,
bias, bias,
) )
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
from typing import Optional, Tuple
import torch
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding as VLLMRotaryEmbedding,
)
class SGLRotaryEmbedding(VLLMRotaryEmbedding):
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from sgl_kernel import rotary_embedding
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
def test_rotary_embedding():
# Test case 1: FP32
def run_test(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
batch_size,
seq_len,
num_heads,
test_name,
):
print(f"\nRunning {test_name}...")
# Initialize both implementations
sgl_rope = SGLRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
).to("cuda")
vllm_rope = VLLMRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
).to("cuda")
# Regular forward pass
positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
)
key = torch.randn(
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
)
# Make copies for both implementations
query_sgl = query.clone()
key_sgl = key.clone()
query_vllm = query.clone()
key_vllm = key.clone()
# Run both implementations
query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
positions, query_sgl, key_sgl
)
query_vllm_out, key_vllm_out = vllm_rope.forward_native(
positions, query_vllm, key_vllm
)
# Compare outputs
torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)
print(f"{test_name} passed!")
# Test Case 1: FP32 with larger dimensions
run_test(
head_size=128,
rotary_dim=64,
max_position=4096,
base=10000,
is_neox_style=True,
dtype=torch.float32,
batch_size=4,
seq_len=32,
num_heads=8,
test_name="FP32 Test",
)
# Test Case 2: BF16 with smaller dimensions
run_test(
head_size=64,
rotary_dim=32,
max_position=2048,
base=8000,
is_neox_style=True,
dtype=torch.bfloat16,
batch_size=2,
seq_len=16,
num_heads=4,
test_name="BF16 Test",
)
if __name__ == "__main__":
test_rotary_embedding()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment