Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
80e2c4a8
"vscode:/vscode.git/clone" did not exist on "9fea3bc470a6e65ec91652ba504f56eb6e2ff0bd"
Unverified
Commit
80e2c4a8
authored
Nov 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Nov 18, 2024
Browse files
Fix chunked prefill with output logprob (#2083)
parent
66318ffe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
9 deletions
+38
-9
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+5
-1
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+33
-8
No files found.
python/sglang/srt/managers/schedule_policy.py
View file @
80e2c4a8
...
@@ -302,7 +302,11 @@ class PrefillAdder:
...
@@ -302,7 +302,11 @@ class PrefillAdder:
if
(
if
(
self
.
rem_chunk_tokens
is
None
self
.
rem_chunk_tokens
is
None
or
input_tokens
<=
self
.
rem_chunk_tokens
or
input_tokens
<=
self
.
rem_chunk_tokens
or
(
req
.
return_logprob
and
req
.
normalized_prompt_logprob
is
None
)
or
(
req
.
return_logprob
and
req
.
normalized_prompt_logprob
is
None
and
req
.
logprob_start_len
!=
len
(
req
.
origin_input_ids
)
-
1
)
):
):
# Non-chunked prefill
# Non-chunked prefill
self
.
can_run_list
.
append
(
req
)
self
.
can_run_list
.
append
(
req
)
...
...
test/srt/test_srt_endpoint.py
View file @
80e2c4a8
"""
"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_
parallel_sample
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_
logprob_with_chunked_prefill
"""
"""
import
json
import
json
...
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
...
@@ -116,22 +116,47 @@ class TestSRTEndpoint(unittest.TestCase):
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
for
i
,
res
in
enumerate
(
response_json
):
for
i
,
res
in
enumerate
(
response_json
):
assert
res
[
"meta_info"
][
"prompt_tokens"
]
==
logprob_start_len
+
1
+
len
(
self
.
assertEqual
(
res
[
"meta_info"
][
"input_token_logprobs"
]
res
[
"meta_info"
][
"prompt_tokens"
],
logprob_start_len
+
1
+
len
(
res
[
"meta_info"
][
"input_token_logprobs"
]),
)
)
assert
prompts
[
i
].
endswith
(
assert
prompts
[
i
].
endswith
(
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"input_token_logprobs"
]])
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"input_token_logprobs"
]])
)
)
assert
res
[
"meta_info"
][
"completion_tokens"
]
==
new_tokens
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
assert
len
(
res
[
"meta_info"
][
"output_token_logprobs"
])
==
new_tokens
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
res
[
"text"
]
==
""
.
join
(
self
.
assertEqual
(
[
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]
res
[
"text"
],
""
.
join
([
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]),
)
)
def
test_logprob_with_chunked_prefill
(
self
):
new_tokens
=
4
prompts
=
"I have a very good idea on this. "
*
8000
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prompts
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
new_tokens
,
},
"return_logprob"
:
True
,
"logprob_start_len"
:
-
1
,
},
)
response_json
=
response
.
json
()
print
(
json
.
dumps
(
response_json
,
indent
=
2
))
res
=
response_json
self
.
assertEqual
(
res
[
"meta_info"
][
"completion_tokens"
],
new_tokens
)
self
.
assertEqual
(
len
(
res
[
"meta_info"
][
"output_token_logprobs"
]),
new_tokens
)
def
test_get_memory_pool_size
(
self
):
def
test_get_memory_pool_size
(
self
):
response
=
requests
.
post
(
self
.
base_url
+
"/get_memory_pool_size"
)
response
=
requests
.
post
(
self
.
base_url
+
"/get_memory_pool_size"
)
assert
isi
nstance
(
response
.
json
(),
int
)
self
.
assert
IsI
nstance
(
response
.
json
(),
int
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment