Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
988d0a4b
Unverified
Commit
988d0a4b
authored
Jan 27, 2025
by
Byron Hsu
Committed by
GitHub
Jan 28, 2025
Browse files
[kernel] Use sgl_kernel rope (#3169)
Co-authored-by:
zhyncs
<
me@zhyncs.com
>
parent
81262c7b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
16 deletions
+45
-16
python/sglang/srt/layers/rotary_embedding.py
python/sglang/srt/layers/rotary_embedding.py
+28
-12
test/srt/test_session_control.py
test/srt/test_session_control.py
+17
-4
No files found.
python/sglang/srt/layers/rotary_embedding.py
View file @
988d0a4b
...
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
...
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
sglang.srt.layers.custom_op_util
import
register_custom_op
from
sglang.srt.layers.custom_op_util
import
register_custom_op
from
sglang.srt.utils
import
is_cuda_available
_is_cuda_available
=
is_cuda_available
()
if
_is_cuda_available
:
from
sgl_kernel
import
apply_rope_with_cos_sin_cache_inplace
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
...
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
if
not
_is_cuda_available
:
cache
=
cache
.
to
(
dtype
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
...
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
...
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
if
_is_cuda_available
:
apply_rope_with_cos_sin_cache_inplace
(
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
positions
=
positions
,
ops
.
rotary_embedding
(
query
=
query
,
positions
,
key
=
key
,
query
,
head_size
=
self
.
head_size
,
key
,
cos_sin_cache
=
self
.
cos_sin_cache
,
self
.
head_size
,
is_neox
=
self
.
is_neox_style
,
self
.
cos_sin_cache
,
)
self
.
is_neox_style
,
else
:
)
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
)
return
query
,
key
return
query
,
key
def
forward_xpu
(
def
forward_xpu
(
...
...
test/srt/test_session_control.py
View file @
988d0a4b
...
@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
...
@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids
[
i
]
=
chunks_ids
[
i
][
1
:]
chunks_ids
[
i
]
=
chunks_ids
[
i
][
1
:]
# 1. using session control
# 1. using session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
session_id
=
requests
.
post
(
session_id
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
},
json
=
{
"capacity_of_str_len"
:
1000
},
...
@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
...
@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
print
(
outputs_from_session
)
print
(
outputs_from_session
)
print
(
"outputs from normal queries:"
)
print
(
"outputs from normal queries:"
)
print
(
outputs_normal
)
print
(
outputs_normal
)
assert
outputs_from_session
==
outputs_normal
assert
(
outputs_from_session
==
outputs_normal
),
f
"outputs_from_session:
{
outputs_from_session
}
, outputs_normal:
{
outputs_normal
}
"
async
def
async_generate
(
self
,
payload
):
async
def
async_generate
(
self
,
payload
):
url
=
self
.
base_url
+
"/generate"
url
=
self
.
base_url
+
"/generate"
...
@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
...
@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
chunks_ids
[
i
]
=
chunks_ids
[
i
][
1
:]
chunks_ids
[
i
]
=
chunks_ids
[
i
][
1
:]
# 1. using session control
# 1. using session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
session_id
=
requests
.
post
(
session_id
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
},
json
=
{
"capacity_of_str_len"
:
1000
},
...
@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
...
@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
else
:
else
:
# 2. not using session control
# 2. not using session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
output_ids
=
tokenizer
.
encode
(
gen_so_far
)
output_ids
=
tokenizer
.
encode
(
gen_so_far
)
if
output_ids
[
0
]
==
tokenizer
.
bos_token_id
:
if
output_ids
[
0
]
==
tokenizer
.
bos_token_id
:
output_ids
=
output_ids
[
1
:]
output_ids
=
output_ids
[
1
:]
...
@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
...
@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
output_no_session
=
response
[
"text"
]
output_no_session
=
response
[
"text"
]
print
(
"second request output without session:"
)
print
(
"second request output without session:"
)
print
(
output_no_session
)
print
(
output_no_session
)
assert
second_output
==
output_no_session
assert
(
second_output
==
output_no_session
),
f
"second_output:
{
second_output
}
, output_no_session:
{
output_no_session
}
"
def
test_session_control_backtrack_with_abort
(
self
):
def
test_session_control_backtrack_with_abort
(
self
):
asyncio
.
run
(
self
.
run_session_control_backtrack_with_abort
(
replace
=
True
))
asyncio
.
run
(
self
.
run_session_control_backtrack_with_abort
(
replace
=
True
))
...
@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
...
@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
assert
len
(
x
)
==
len
(
chunks_per_step
[
0
])
assert
len
(
x
)
==
len
(
chunks_per_step
[
0
])
# 1. using session control
# 1. using session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
session_id
=
requests
.
post
(
session_id
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
},
json
=
{
"capacity_of_str_len"
:
1000
},
...
@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
...
@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
print
(
outputs_from_session
)
print
(
outputs_from_session
)
print
(
"====== outputs from normal queries: ======="
)
print
(
"====== outputs from normal queries: ======="
)
print
(
outputs_normal
)
print
(
outputs_normal
)
assert
outputs_from_session
==
outputs_normal
assert
(
outputs_from_session
==
outputs_normal
),
f
"outputs_from_session:
{
outputs_from_session
}
, outputs_normal:
{
outputs_normal
}
"
def
test_session_control_with_branching
(
self
):
def
test_session_control_with_branching
(
self
):
root_prompt
=
"First, let me explain in one sentence about AI"
root_prompt
=
"First, let me explain in one sentence about AI"
...
@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
...
@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
gen_len
=
32
gen_len
=
32
# 1. using session control
# 1. using session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
session_id
=
requests
.
post
(
session_id
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
},
json
=
{
"capacity_of_str_len"
:
1000
},
...
@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
...
@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
print
(
outputs_from_session
)
print
(
outputs_from_session
)
print
(
"outputs from normal queries:"
)
print
(
"outputs from normal queries:"
)
print
(
outputs_normal
)
print
(
outputs_normal
)
assert
outputs_from_session
==
outputs_normal
assert
(
outputs_from_session
==
outputs_normal
),
f
"outputs_from_session:
{
outputs_from_session
}
, outputs_normal:
{
outputs_normal
}
"
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment