Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
80e2c4a8
Unverified
Commit
80e2c4a8
authored
Nov 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 18, 2024
Browse files
Fix chunked prefill with output logprob (#2083)
parent
66318ffe
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
9 deletions
+38
-9
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+5
-1
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+33
-8
No files found.
python/sglang/srt/managers/schedule_policy.py
View file @
80e2c4a8
...
@@ -302,7 +302,11 @@ class PrefillAdder:
...
@@ -302,7 +302,11 @@ class PrefillAdder:
if
(
if
(
self
.
rem_chunk_tokens
is
None
self
.
rem_chunk_tokens
is
None
or
input_tokens
<=
self
.
rem_chunk_tokens
or
input_tokens
<=
self
.
rem_chunk_tokens
or
(
req
.
return_logprob
and
req
.
normalized_prompt_logprob
is
None
)
or
(
req
.
return_logprob
and
req
.
normalized_prompt_logprob
is
None
and
req
.
logprob_start_len
!=
len
(
req
.
origin_input_ids
)
-
1
)
):
):
# Non-chunked prefill
# Non-chunked prefill
self
.
can_run_list
.
append
(
req
)
self
.
can_run_list
.
append
(
req
)
...
...
test/srt/test_srt_endpoint.py
View file @
80e2c4a8
"""
"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_
parallel_sample
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_
logprob_with_chunked_prefill
"""
"""
import
json
import
json
...
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
...
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
for
i
,
res
in
enumerate
(
response_json
):
for
i
,
res
in
enumerate
(
response_json
):
assert
res
[
"meta_info"
][
"prompt_tokens"
]
==
logprob_start_len
+
1
+
len
(
self
.
assertEqual
(
res
[
"meta_info"
][
"input_token_logprobs"
]
res
[
"meta_info"
][
"prompt_tokens"
],
logprob_start_len
+
1
+
len
(
res
[
"meta_info"
][
"input_token_logprobs"
]),
)
)
assert
prompts
[
i
].
endswith
(
assert
prompts
[
i
].
endswith
(
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"input_token_logprobs"
]])
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"input_token_logprobs"
]])
)
)
assert
res
[
"meta_info"
][
"completion_tokens"
]
==
new_tokens
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
assert
len
(
res
[
"meta_info"
][
"output_token_logprobs"
])
==
new_tokens
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
res
[
"text"
]
==
""
.
join
(
self
.
assertEqual
(
[
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]
res
[
"text"
],
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]),
)
)
def
test_logprob_with_chunked_prefill
(
self
):
new_tokens
=
4
prompts
=
"I have a very good idea on this. "
*
8000
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prompts
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
new_tokens
,
},
"return_logprob"
:
True
,
"logprob_start_len"
:
-
1
,
},
)
response_json
=
response
.
json
()
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
res
=
response_json
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
def
test_get_memory_pool_size
(
self
):
def
test_get_memory_pool_size
(
self
):
response
=
requests
.
post
(
self
.
base_url
+
"/get_memory_pool_size"
)
response
=
requests
.
post
(
self
.
base_url
+
"/get_memory_pool_size"
)
assert
isi
nstance
(
response
.
json
(),
int
)
self
.
assert
IsI
nstance
(
response
.
json
(),
int
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment