Commit b5caa22d (unverified), authored Jan 20, 2025 by Byron Hsu, committed by GitHub on Jan 20, 2025

[kernel] port rope cuda kernel to sgl-kernel (#2993)

Co-authored-by: Yineng Zhang <me@zhyncs.com>

Parent: 73401fd0

Showing 8 changed files with 255 additions and 1 deletion (+255 -1)
.gitignore                                           +3    -0
sgl-kernel/pyproject.toml                            +1    -1
sgl-kernel/setup.py                                  +1    -0
sgl-kernel/src/sgl-kernel/__init__.py                +2    -0
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu   +119  -0
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu     +6    -0
sgl-kernel/src/sgl-kernel/ops/__init__.py            +5    -0
sgl-kernel/tests/test_rotary_embedding.py            +118  -0
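For context (not part of the diff): rotary position embedding (RoPE) rotates pairs of query/key channels by a position-dependent angle read from a precomputed cos/sin cache. For a channel pair (x, y) at rotary offset i and position p, the kernel added here applies, in the usual RoPE convention,

    x' = x * cos(theta_{p,i}) - y * sin(theta_{p,i})
    y' = x * sin(theta_{p,i}) + y * cos(theta_{p,i}),    theta_{p,i} = p * base^(-2i / rot_dim)

where NeoX style pairs channel i with channel i + rot_dim/2 and GPT-J style pairs channels 2i and 2i+1.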
.gitignore
View file @
b5caa22d
...
@@ -222,3 +222,6 @@ work_dirs/
...
@@ -222,3 +222,6 @@ work_dirs/
compile_commands.json
compile_commands.json
*.iml
*.iml
# VSCode
.vscode
sgl-kernel/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post14"
+version = "0.0.2.post15"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
sgl-kernel/setup.py

@@ -53,6 +53,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
             "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+            "src/sgl-kernel/csrc/rotary_embedding.cu",
         ],
         include_dirs=include_dirs,
         extra_compile_args={
sgl-kernel/src/sgl-kernel/__init__.py

@@ -6,6 +6,7 @@ from sgl_kernel.ops import (
     int8_scaled_mm,
     moe_align_block_size,
     register_graph_buffers,
+    rotary_embedding,
     sampling_scaling_penalties,
 )

@@ -18,4 +19,5 @@ __all__ = [
     "sampling_scaling_penalties",
     "get_graph_buffer_ipc_meta",
     "register_graph_buffers",
+    "rotary_embedding",
 ]
sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu (new file, mode 100644)

// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr,
                                                    const scalar_t* __restrict__ cos_ptr,
                                                    const scalar_t* __restrict__ sin_ptr,
                                                    int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads, head_size] or
                                   // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads, head_size] or
                                   // [num_tokens, num_kv_heads, head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads,
    const int rot_dim, const int token_idx, const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
    scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or
                                                 // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or
                                                 // [num_tokens, num_kv_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads,
    const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads,
                                            rot_dim, token_idx, query_stride, key_stride);
}

void rotary_embedding(torch::Tensor& positions,      // [batch_size, seq_len] or [num_tokens]
                      torch::Tensor& query,          // [batch_size, seq_len, num_heads * head_size] or
                                                     // [num_tokens, num_heads * head_size]
                      torch::Tensor& key,            // [batch_size, seq_len, num_kv_heads * head_size] or
                                                     // [num_tokens, num_kv_heads * head_size]
                      int64_t head_size,
                      torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
                      bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride, num_heads,
              num_kv_heads, head_size);
        } else {
          rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride, num_heads,
              num_kv_heads, head_size);
        }
      });
}
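For reference only (not part of the commit): a minimal PyTorch sketch of what the NeoX-style path of this kernel does per token, assuming a flat tensor of shape [num_tokens, num_heads * head_size] and a cos_sin_cache of shape [max_position, rot_dim] whose first half per row is cos and second half is sin. The helper name rope_neox_reference is illustrative, not part of sgl-kernel.

import torch

def rope_neox_reference(x, positions, cos_sin_cache, head_size, rot_dim):
    # x: [num_tokens, num_heads * head_size]; rotated in place, mirroring the kernel.
    num_tokens, hidden = x.shape
    num_heads = hidden // head_size
    embed_dim = rot_dim // 2
    cache = cos_sin_cache[positions]           # [num_tokens, rot_dim]
    cos = cache[:, :embed_dim].unsqueeze(1)    # [num_tokens, 1, embed_dim]
    sin = cache[:, embed_dim:].unsqueeze(1)
    xh = x.view(num_tokens, num_heads, head_size)
    a = xh[..., :embed_dim].clone()            # channel i
    b = xh[..., embed_dim:rot_dim].clone()     # channel i + rot_dim // 2
    xh[..., :embed_dim] = a * cos - b * sin
    xh[..., embed_dim:rot_dim] = b * cos + a * sin
    return x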
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu

@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);

+// rotary embedding
+void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key,
+                      int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");

@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
   // int8_scaled_mm
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+  // rotary embedding
+  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
 }
sgl-kernel/src/sgl-kernel/ops/__init__.py

@@ -7,6 +7,7 @@ from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
 from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
+from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )

@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
         out_dtype,
         bias,
     )
+
+
+def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
+    return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
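A usage sketch for the new Python wrapper (not part of the commit). The cos_sin_cache construction below assumes the [max_position, rot_dim] layout with the cos half followed by the sin half, matching the kernel's indexing; the frequency formula is the standard RoPE convention and is an assumption here.

import torch
from sgl_kernel import rotary_embedding

head_size, rot_dim, num_heads, num_tokens, base = 128, 64, 8, 16, 10000

# Precompute the cache: row p holds [cos(p * inv_freq), sin(p * inv_freq)].
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
t = torch.arange(4096, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).cuda()

positions = torch.arange(num_tokens, device="cuda")  # int64 positions
query = torch.randn(num_tokens, num_heads * head_size, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size, device="cuda")

# Rotates the first rot_dim channels of every head of query and key in place (NeoX style).
rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)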
sgl-kernel/tests/test_rotary_embedding.py
0 → 100644
View file @
b5caa22d
from
typing
import
Optional
,
Tuple
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
(
RotaryEmbedding
as
VLLMRotaryEmbedding
,
)
class
SGLRotaryEmbedding
(
VLLMRotaryEmbedding
):
def
forward_cuda
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
sgl_kernel
import
rotary_embedding
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
)
return
query
,
key
# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
def
test_rotary_embedding
():
# Test case 1: FP32
def
run_test
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
batch_size
,
seq_len
,
num_heads
,
test_name
,
):
print
(
f
"
\n
Running
{
test_name
}
..."
)
# Initialize both implementations
sgl_rope
=
SGLRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
).
to
(
"cuda"
)
vllm_rope
=
VLLMRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
).
to
(
"cuda"
)
# Regular forward pass
positions
=
torch
.
arange
(
seq_len
,
device
=
"cuda"
).
repeat
(
batch_size
)
query
=
torch
.
randn
(
batch_size
*
seq_len
,
num_heads
*
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
key
=
torch
.
randn
(
batch_size
*
seq_len
,
num_heads
*
head_size
,
device
=
"cuda"
,
dtype
=
dtype
)
# Make copies for both implementations
query_sgl
=
query
.
clone
()
key_sgl
=
key
.
clone
()
query_vllm
=
query
.
clone
()
key_vllm
=
key
.
clone
()
# Run both implementations
query_sgl_out
,
key_sgl_out
=
sgl_rope
.
forward_cuda
(
positions
,
query_sgl
,
key_sgl
)
query_vllm_out
,
key_vllm_out
=
vllm_rope
.
forward_native
(
positions
,
query_vllm
,
key_vllm
)
# Compare outputs
torch
.
testing
.
assert_close
(
query_sgl_out
,
query_vllm_out
,
rtol
=
1e-3
,
atol
=
1e-3
)
torch
.
testing
.
assert_close
(
key_sgl_out
,
key_vllm_out
,
rtol
=
1e-3
,
atol
=
1e-3
)
print
(
f
"
{
test_name
}
passed!"
)
# Test Case 1: FP32 with larger dimensions
run_test
(
head_size
=
128
,
rotary_dim
=
64
,
max_position
=
4096
,
base
=
10000
,
is_neox_style
=
True
,
dtype
=
torch
.
float32
,
batch_size
=
4
,
seq_len
=
32
,
num_heads
=
8
,
test_name
=
"FP32 Test"
,
)
# Test Case 2: BF16 with smaller dimensions
run_test
(
head_size
=
64
,
rotary_dim
=
32
,
max_position
=
2048
,
base
=
8000
,
is_neox_style
=
True
,
dtype
=
torch
.
bfloat16
,
batch_size
=
2
,
seq_len
=
16
,
num_heads
=
4
,
test_name
=
"BF16 Test"
,
)
if
__name__
==
"__main__"
:
test_rotary_embedding
()
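Note (not part of the commit): the test imports RotaryEmbedding from vllm and allocates CUDA tensors, so it needs a CUDA device and a vllm installation in addition to the freshly built sgl-kernel package; it can be run directly as a script via the __main__ guard above or collected by pytest.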