Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
07440f5f
"fair_dev/testing/testing.py" did not exist on "49a198c99cdf61cf869ced2dc1e4e8b69926ceed"
Unverified
Commit
07440f5f
authored
Sep 28, 2025
by
Lianmin Zheng
Committed by
GitHub
Sep 28, 2025
Browse files
Fix FusedSetKVBufferArg in RotaryEmbedding (#11003)
parent
9816989b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
9 deletions
+32
-9
python/sglang/srt/layers/rotary_embedding.py
python/sglang/srt/layers/rotary_embedding.py
+25
-6
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
+7
-3
No files found.
python/sglang/srt/layers/rotary_embedding.py
View file @
07440f5f
...
...
@@ -27,7 +27,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu
=
is_cpu
()
if
_is_cuda
:
from
sgl_kernel
import
apply_rope_with_cos_sin_cache_inplace
from
sgl_kernel
import
FusedSetKVBufferArg
,
apply_rope_with_cos_sin_cache_inplace
else
:
FusedSetKVBufferArg
=
None
if
_use_aiter
:
from
aiter.rotary_embedding
import
get_rope
as
aiter_get_rope
...
...
@@ -146,8 +149,13 @@ class RotaryEmbedding(CustomOp):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""A PyTorch-native implementation of forward()."""
assert
(
fused_set_kv_buffer_arg
is
None
),
"fused_set_kv_buffer_arg is not supported for native implementation"
if
offsets
is
not
None
:
positions
=
positions
+
offsets
positions
=
positions
.
flatten
()
...
...
@@ -176,12 +184,17 @@ class RotaryEmbedding(CustomOp):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""A PyTorch-npu implementation of forward()."""
import
os
assert
(
fused_set_kv_buffer_arg
is
None
),
"fused_set_kv_buffer_arg is not supported for npu implementation"
if
get_bool_env_var
(
"SGLANG_ENABLE_TORCH_COMPILE"
):
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
,
fused_set_kv_buffer_arg
)
else
:
rotary_mode
=
"half"
if
self
.
is_neox_style
:
...
...
@@ -206,8 +219,12 @@ class RotaryEmbedding(CustomOp):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
=
None
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
(
fused_set_kv_buffer_arg
is
None
),
"fused_set_kv_buffer_arg is not supported for cpu implementation"
positions
=
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
if
_is_cpu_amx_available
:
return
torch
.
ops
.
sgl_kernel
.
rotary_embedding_cpu
(
...
...
@@ -219,7 +236,9 @@ class RotaryEmbedding(CustomOp):
self
.
is_neox_style
,
)
else
:
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
,
fused_set_kv_buffer_arg
)
def
forward_cuda
(
self
,
...
...
@@ -227,7 +246,7 @@ class RotaryEmbedding(CustomOp):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
=
None
,
#
Optional[FusedSetKVBufferArg]
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
_is_cuda
and
(
self
.
head_size
in
[
64
,
128
,
256
,
512
]):
apply_rope_with_cos_sin_cache_inplace
(
...
...
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
View file @
07440f5f
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
pytest
import
torch
from
sgl_kernel
import
FusedSetKVBufferArg
,
apply_rope_with_cos_sin_cache_inplace
...
...
@@ -84,8 +83,13 @@ class RotaryEmbedding(torch.nn.Module):
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""A PyTorch-native implementation of forward()."""
assert
(
fused_set_kv_buffer_arg
is
None
),
"fused_set_kv_buffer_arg is not supported for native implementation"
if
offsets
is
not
None
:
positions
=
positions
+
offsets
...
...
@@ -125,8 +129,8 @@ class FlashInferRotaryEmbedding(RotaryEmbedding):
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
fused_set_kv_buffer_arg
:
Optional
[
FusedSetKVBufferArg
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
apply_rope_with_cos_sin_cache_inplace
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment