Commit 07440f5f (unverified)

Fix FusedSetKVBufferArg in RotaryEmbedding (#11003)

Authored Sep 28, 2025 by Lianmin Zheng; committed by GitHub on Sep 28, 2025.
Parent: 9816989b
Showing 2 changed files with 32 additions and 9 deletions (+32 −9):

python/sglang/srt/layers/rotary_embedding.py (+25 −6)
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py (+7 −3)
File: python/sglang/srt/layers/rotary_embedding.py
@@ -27,7 +27,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None

 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope
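The new else branch is what keeps this module importable on builds without the CUDA sgl_kernel extension: the Optional[FusedSetKVBufferArg] annotations added below are evaluated at definition time, so the name must exist even when the import is skipped. A minimal sketch of the pattern, with the platform check replaced by a plain try/except for illustration:

    from typing import Optional, Tuple

    import torch

    try:
        # Real import on CUDA builds; stands in for the `if _is_cuda:` check.
        from sgl_kernel import FusedSetKVBufferArg
    except ImportError:
        # Bind the name so `Optional[FusedSetKVBufferArg]` below still evaluates:
        # typing treats Optional[None] as Optional[NoneType], which is harmless
        # on a code path that never receives the argument anyway.
        FusedSetKVBufferArg = None

    def forward_native(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return query, key  # rotary math elided in this sketch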
@@ -146,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
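forward_native gains the parameter so every backend shares one signature, but it rejects a non-None value instead of silently ignoring it; an ignored argument would mean the caller believes the fused KV-buffer write happened when it did not. A runnable sketch of the guard (the rotary math itself is elided):

    import torch

    def forward_native(positions, query, key, offsets=None, fused_set_kv_buffer_arg=None):
        # Fail loudly: silently dropping the argument would skip the KV write.
        assert (
            fused_set_kv_buffer_arg is None
        ), "fused_set_kv_buffer_arg is not supported for native implementation"
        return query, key  # rotary embedding math elided in this sketch

    pos = torch.zeros(2, dtype=torch.long)
    q, k = torch.zeros(2, 8), torch.zeros(2, 8)

    forward_native(pos, q, k)  # fine
    try:
        forward_native(pos, q, k, fused_set_kv_buffer_arg=object())
    except AssertionError as exc:
        print(exc)  # fused_set_kv_buffer_arg is not supported for native implementation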
@@ -176,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
         import os

+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
+
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
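The other half of the fix is visible in the torch.compile branch: fused_set_kv_buffer_arg is now forwarded into forward_native rather than dropped at the delegation boundary, so the assert above can actually fire. A self-contained before/after sketch (function names are illustrative):

    def forward_native(positions, query, key, offsets=None, fused_set_kv_buffer_arg=None):
        # Callee-side guard, as in the hunk above.
        assert fused_set_kv_buffer_arg is None, "not supported on this path"
        return query, key

    # Before: the wrapper silently swallowed the argument.
    def forward_npu_before(positions, query, key, offsets=None, fused_set_kv_buffer_arg=None):
        return forward_native(positions, query, key, offsets)  # arg dropped here

    # After: everything is forwarded, so unsupported usage surfaces at the callee.
    def forward_npu_after(positions, query, key, offsets=None, fused_set_kv_buffer_arg=None):
        return forward_native(positions, query, key, offsets, fused_set_kv_buffer_arg)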
@@ -206,8 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -219,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )

     def forward_cuda(
         self,
@@ -227,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg=None,  # Optional[FusedSetKVBufferArg]
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
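In forward_cuda the change is signature-only: the comment `# Optional[FusedSetKVBufferArg]` becomes a real annotation. Comments are invisible to type checkers, so the old form accepted any value without complaint; a small sketch of the difference (the checker behavior described in the comments is the usual mypy/pyright behavior, not something this commit tests):

    from typing import Optional

    class FusedSetKVBufferArg:  # stand-in declaration for this sketch
        pass

    def before(fused_set_kv_buffer_arg=None):  # Optional[FusedSetKVBufferArg]
        """The type lives only in a comment; checkers infer no constraint."""

    def after(fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None):
        """The type is enforced by checkers and surfaced by IDEs."""

    before(fused_set_kv_buffer_arg=42)  # unchecked: nothing flags the bad value
    after(fused_set_kv_buffer_arg=42)   # runs, but a checker reports the mismatch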
File: sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

-import pytest
 import torch
 from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -84,8 +83,13 @@ class RotaryEmbedding(torch.nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
@@ -125,8 +129,8 @@ class FlashInferRotaryEmbedding(RotaryEmbedding):
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
-        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         apply_rope_with_cos_sin_cache_inplace(
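This reorder in the test helper is the subtle part of the fix: the base class declares offsets as the fourth parameter, so an override that put fused_set_kv_buffer_arg in that slot would mis-bind positional calls made through the base-class interface. A runnable sketch of that hazard, with illustrative class names:

    import torch

    class Base:
        def forward(self, positions, query, key, offsets=None, fused_set_kv_buffer_arg=None):
            return offsets

    class MisorderedOverride(Base):
        # Old order: fused_set_kv_buffer_arg declared before offsets.
        def forward(self, positions, query, key, fused_set_kv_buffer_arg=None, offsets=None):
            return fused_set_kv_buffer_arg  # receives whatever arrives fourth

    pos = torch.zeros(2, dtype=torch.long)
    q, k = torch.zeros(2, 8), torch.zeros(2, 8)
    offsets = torch.ones(2, dtype=torch.long)

    # A caller written against Base passes offsets positionally in fourth place;
    # the misordered override binds that tensor to fused_set_kv_buffer_arg instead.
    leaked = MisorderedOverride().forward(pos, q, k, offsets)
    print(torch.equal(leaked, offsets))  # True: offsets landed in the wrong slot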