Unverified Commit fb11a439 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[kernel] Integrate flashinfer's rope with higher precision and better perf (#3134)

parent af02f99b
Subproject commit 6e6f38d3534994c34b2c6b09b5b45c8a7b92ffd2
Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5
......@@ -94,6 +94,7 @@ sources = [
"3rdparty/flashinfer/csrc/norm.cu",
"3rdparty/flashinfer/csrc/sampling.cu",
"3rdparty/flashinfer/csrc/renorm.cu",
"3rdparty/flashinfer/csrc/rope.cu",
]
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
......
from sgl_kernel.ops import (
apply_rope_with_cos_sin_cache_inplace,
bmm_fp8,
custom_dispose,
custom_reduce,
......@@ -25,6 +26,7 @@ from sgl_kernel.ops import (
)
__all__ = [
"apply_rope_with_cos_sin_cache_inplace",
"bmm_fp8",
"custom_dispose",
"custom_reduce",
......
......@@ -98,7 +98,7 @@ void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [nu
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 grid(num_tokens); // each block is responsible for one token
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
......@@ -112,3 +112,7 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
int64_t cuda_stream);
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
int64_t cuda_stream);
......@@ -10,6 +10,60 @@ from sgl_kernel.ops.utils import (
)
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
The result is inplace applied to the input tensors.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
Cosine is the first half and Sine is the second half on rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
with query.device as device:
pos_ids = pos_ids.int()
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions,
interleave=(not is_neox),
cuda_stream=_get_cuda_stream(device),
)
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
......
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
......@@ -116,6 +115,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
// apply rope with cos sin cache
m.def(
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
}
REGISTER_EXTENSION(_kernels)
from typing import Optional, Tuple
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import (
RotaryEmbedding as VLLMRotaryEmbedding,
)
import torch.nn as nn
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
# vLLM torch native
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
class RotaryEmbedding(torch.nn.Module):
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
class SGLRotaryEmbedding(VLLMRotaryEmbedding):
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_cuda(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from sgl_kernel import rotary_embedding
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
# Modification: float32 is required for the rotary embedding to work correctly
query = query.to(torch.float32)
key = key.to(torch.float32)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
# Modification: convert to the correct dtype
query = query.to(self.dtype)
key = key.to(self.dtype)
return query, key
# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
def test_rotary_embedding():
# Test case 1: FP32
def run_test(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
batch_size,
seq_len,
num_heads,
test_name,
):
print(f"\nRunning {test_name}...")
# Initialize both implementations
sgl_rope = SGLRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
).to("cuda")
vllm_rope = VLLMRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
).to("cuda")
# Regular forward pass
positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
)
key = torch.randn(
batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
class FlashInferRotaryEmbedding(RotaryEmbedding):
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
# Make copies for both implementations
query_sgl = query.clone()
key_sgl = key.clone()
query_vllm = query.clone()
key_vllm = key.clone()
return query, key
# Run both implementations
query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
positions, query_sgl, key_sgl
)
query_vllm_out, key_vllm_out = vllm_rope.forward_native(
positions, query_vllm, key_vllm
)
# Compare outputs
torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)
print(f"{test_name} passed!")
# Test Case 1: FP32 with larger dimensions
run_test(
head_size=128,
rotary_dim=64,
max_position=4096,
base=10000,
is_neox_style=True,
dtype=torch.float32,
batch_size=4,
seq_len=32,
num_heads=8,
test_name="FP32 Test",
@pytest.mark.parametrize(
"head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
[
(64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
(256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
(512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
(128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
(512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
],
)
def test_correctness(
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
batch_size: int,
seq_len: int,
num_q_heads: int,
num_kv_heads: int,
):
rope_ref = RotaryEmbedding(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
).to(device)
rope_flashinfer = FlashInferRotaryEmbedding(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
).to(device)
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
)
key = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
# Test Case 2: BF16 with smaller dimensions
run_test(
head_size=64,
rotary_dim=32,
max_position=2048,
base=8000,
is_neox_style=True,
dtype=torch.bfloat16,
batch_size=2,
seq_len=16,
num_heads=4,
test_name="BF16 Test",
query_ref, key_ref = query.clone(), key.clone()
query_flashinfer, key_flashinfer = query.clone(), key.clone()
query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
pos_ids, query_flashinfer, key_flashinfer
)
print(query_ref_out)
print(query_flashinfer_out)
if __name__ == "__main__":
test_rotary_embedding()
torch.testing.assert_close(
query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
)
torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment