change / sglang · Commits

Commit fb11a439 (unverified)
[kernel] Integrate flashinfer's rope with higher precision and better perf (#3134)
Authored Jan 26, 2025 by Byron Hsu; committed via GitHub on Jan 27, 2025
Parent: af02f99b
Showing 8 changed files with 244 additions and 98 deletions (+244, −98):
sgl-kernel/3rdparty/flashinfer                         +1   −1
sgl-kernel/setup.py                                    +1   −0
sgl-kernel/src/sgl-kernel/__init__.py                  +2   −0
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu     +1   −1
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h    +4   −0
sgl-kernel/src/sgl-kernel/ops/__init__.py              +54  −0
sgl-kernel/src/sgl-kernel/torch_extension.cc           +6   −1
sgl-kernel/tests/test_rotary_embedding.py              +175 −95
sgl-kernel/3rdparty/flashinfer @ 4f1f0898 (compare 6e6f38d3...4f1f0898)
-Subproject commit 6e6f38d3534994c34b2c6b09b5b45c8a7b92ffd2
+Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5
sgl-kernel/setup.py
@@ -94,6 +94,7 @@ sources = [
     "3rdparty/flashinfer/csrc/norm.cu",
     "3rdparty/flashinfer/csrc/sampling.cu",
     "3rdparty/flashinfer/csrc/renorm.cu",
+    "3rdparty/flashinfer/csrc/rope.cu",
 ]

 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
sgl-kernel/src/sgl-kernel/__init__.py
View file @
fb11a439
from
sgl_kernel.ops
import
(
apply_rope_with_cos_sin_cache_inplace
,
bmm_fp8
,
custom_dispose
,
custom_reduce
,
...
...
@@ -25,6 +26,7 @@ from sgl_kernel.ops import (
)
__all__
=
[
"apply_rope_with_cos_sin_cache_inplace"
,
"bmm_fp8"
,
"custom_dispose"
,
"custom_reduce"
,
...
...
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
@@ -98,7 +98,7 @@ void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [nu
   int64_t query_stride = query.stride(-2);
   int64_t key_stride = key.stride(-2);
-  dim3 grid(num_tokens);
+  dim3 grid(num_tokens);  // each block is responsible for one token
   dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
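As a side note on the comment added above: the existing kernel launches one CUDA block per token and sizes each block to cover the per-token rotation work. A minimal Python sketch of that arithmetic (the helper name is ours, not part of the source):

def rope_launch_dims(num_tokens, num_heads, rot_dim):
    # dim3 grid(num_tokens): each block is responsible for one token.
    grid = num_tokens
    # dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)): roughly one thread
    # per (head, rotation pair), capped at 512 threads per block.
    block = min(num_heads * rot_dim // 2, 512)
    return grid, block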
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -112,3 +112,7 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
 void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
                                std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                int64_t cuda_stream);
+
+void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
+                                      at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
+                                      int64_t cuda_stream);
sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -10,6 +10,60 @@ from sgl_kernel.ops.utils import (
 )


+def apply_rope_with_cos_sin_cache_inplace(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool = True,
+) -> None:
+    r"""
+    Apply rotary embedding to keys and queries with precomputed cos/sin values.
+    This is designed to be compatible with the SGL/vLLM implementation.
+    The result is applied in place to the input tensors.
+
+    Parameters
+    ----------
+    positions : torch.Tensor
+        Position indices, shape: ``(nnz)``.
+    query : torch.Tensor
+        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
+    key : torch.Tensor
+        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
+    cos_sin_cache : torch.Tensor
+        Cosine and sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
+        Cosine is the first half and sine is the second half on rotary_dim.
+    is_neox : bool
+        Whether to use Neox-style RoPE, default: ``True``.
+
+        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+    Note
+    ----
+    The rotary dimension is determined by the cosine cache and sine cache.
+    """
+    if cos_sin_cache.dtype != torch.float32:
+        raise ValueError("cos_sin_cache should be float32")
+
+    with query.device as device:
+        positions = positions.int()
+        torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
+            q=query.view(query.shape[0], -1, head_size),
+            k=key.view(key.shape[0], -1, head_size),
+            q_rope=query.view(query.shape[0], -1, head_size),
+            k_rope=key.view(key.shape[0], -1, head_size),
+            cos_sin_cache=cos_sin_cache,
+            pos_ids=positions,
+            interleave=(not is_neox),
+            cuda_stream=_get_cuda_stream(device),
+        )
+
+
 def init_custom_reduce(
     rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
 ):
...
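A minimal usage sketch of the wrapper above, assuming a CUDA device and a built sgl_kernel package; the cache construction mirrors the reference _compute_cos_sin_cache in the test file further below, and all shapes and dtypes follow the docstring (the concrete sizes here are illustrative only):

import torch
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

head_size, num_heads, rotary_dim, max_seq_len = 128, 8, 128, 4096

# The cache must be float32, with cos in the first half and sin in the second half of rotary_dim.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_seq_len, dtype=torch.float)
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to("cuda")

nnz = 16  # number of tokens in the flattened batch
positions = torch.arange(nnz, device="cuda")
query = torch.randn(nnz, num_heads * head_size, dtype=torch.bfloat16, device="cuda")
key = torch.randn(nnz, num_heads * head_size, dtype=torch.bfloat16, device="cuda")

# Rotation is written back into query and key in place.
apply_rope_with_cos_sin_cache_inplace(
    positions=positions,
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
)

Because the wrapper passes the same (nnz, num_heads, head_size) views for q/q_rope and k/k_rope, the FlashInfer kernel writes its output back into the inputs, so callers keep their existing layout and dtype.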
sgl-kernel/src/sgl-kernel/torch_extension.cc
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>
...
@@ -116,6 +115,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

+  // apply rope with cos sin cache
+  m.def(
+      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
+      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
 }

 REGISTER_EXTENSION(_kernels)
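The schema registered above is what the Python wrapper in ops/__init__.py dispatches to. A small sketch, for illustration only, of how the raw op becomes reachable once the extension is imported:

import torch
import sgl_kernel  # importing the package loads the compiled extension and registers the sgl_kernels ops

# The wrapper calls this op packet with q/k views plus the float32 cos/sin cache.
rope_op = torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache
print(rope_op)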
sgl-kernel/tests/test_rotary_embedding.py

Old version (compared sgl_kernel.rotary_embedding against vLLM's RotaryEmbedding):

from typing import Optional, Tuple

import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding as VLLMRotaryEmbedding,
)


class SGLRotaryEmbedding(VLLMRotaryEmbedding):
    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from sgl_kernel import rotary_embedding

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key


# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
def test_rotary_embedding():
    # Test case 1: FP32
    def run_test(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        print(f"\nRunning {test_name}...")

        # Initialize both implementations
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Regular forward pass
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )

        # Make copies for both implementations
        query_sgl = query.clone()
        key_sgl = key.clone()
        query_vllm = query.clone()
        key_vllm = key.clone()

        # Run both implementations
        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(positions, query_sgl, key_sgl)
        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
            positions, query_vllm, key_vllm
        )

        # Compare outputs
        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)

        print(f"{test_name} passed!")

    # Test Case 1: FP32 with larger dimensions
    run_test(
        head_size=128,
        rotary_dim=64,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
        batch_size=4,
        seq_len=32,
        num_heads=8,
        test_name="FP32 Test",
    )

    # Test Case 2: BF16 with smaller dimensions
    run_test(
        head_size=64,
        rotary_dim=32,
        max_position=2048,
        base=8000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        batch_size=2,
        seq_len=16,
        num_heads=4,
        test_name="BF16 Test",
    )


if __name__ == "__main__":
    test_rotary_embedding()

New version (compares the FlashInfer-backed apply_rope_with_cos_sin_cache_inplace against a PyTorch-native reference):

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import pytest
import torch
import torch.nn as nn

from sgl_kernel import apply_rope_with_cos_sin_cache_inplace


# vLLM torch native
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(torch.nn.Module):
    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert to the correct dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)
        return query, key


class FlashInferRotaryEmbedding(RotaryEmbedding):
    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key


@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
    [
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
):
    rope_ref = RotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    ).to(device)

    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )

    query_ref, key_ref = query.clone(), key.clone()
    query_flashinfer, key_flashinfer = query.clone(), key.clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        pos_ids, query_flashinfer, key_flashinfer
    )

    print(query_ref_out)
    print(query_flashinfer_out)

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
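One detail worth noting in the reference above: forward_native deliberately upcasts query and key to float32 before rotating (the "Modification: float32 is required" comment), which appears to line up with the "higher precision" in the commit title, since the removed SGLRotaryEmbedding path cast cos_sin_cache to the query dtype instead. A small self-contained sketch of our own (not part of the commit) illustrating why the upcast matters:

import math

import torch

# Rotate a batch of 2-D pairs in float64 (ground truth), bfloat16, and float32,
# then compare worst-case error against the ground truth.
torch.manual_seed(0)
x = torch.randn(4096, 2, dtype=torch.float64)
theta = torch.rand(4096, 1, dtype=torch.float64) * 2 * math.pi
cos, sin = theta.cos(), theta.sin()

def rotate(x, cos, sin):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

ref = rotate(x, cos, sin)
err_bf16 = (rotate(x.bfloat16(), cos.bfloat16(), sin.bfloat16()).double() - ref).abs().max()
err_fp32 = (rotate(x.float(), cos.float(), sin.float()).double() - ref).abs().max()
print(f"bf16 rotation error: {err_bf16.item():.2e}, fp32 rotation error: {err_fp32.item():.2e}")

The bfloat16 path typically shows errors orders of magnitude larger than the float32 path, which is why the reference rotates in float32 and casts back, and why the comparison tolerances above are 1e-2 for bfloat16 inputs. The parametrized test itself can be run on a CUDA machine with pytest, e.g. pytest sgl-kernel/tests/test_rotary_embedding.py.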