Commit 63ba630b (unverified) in zhaoyu6/sglang
Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)
Authored by Cody Yu on Feb 15, 2024; committed via GitHub on Feb 15, 2024
Parent: 6493256b
Showing 4 changed files with 35 additions and 15 deletions (+35 / -15)
python/sglang/srt/managers/io_struct.py           +3  -1
python/sglang/srt/managers/router/infer_batch.py  +6  -2
python/sglang/srt/managers/router/model_rpc.py    +6  -5
python/sglang/srt/server.py                       +20 -7
python/sglang/srt/managers/io_struct.py

@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
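For orientation, the sketch below shows how a client might set the new field together with the existing logprob options. It is illustrative only: the /generate route, the default port 30000, and the "text" key are assumptions about how the server is usually launched and queried; the remaining keys mirror the GenerateReqInput fields above.

# Hedged sketch, not part of the commit: a client request exercising
# return_text_in_logprobs. The /generate route, port 30000, and the "text"
# key are assumptions; the other keys mirror GenerateReqInput above.
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    "return_logprob": True,           # request per-token logprobs
    "logprob_start_len": 0,           # report prompt logprobs from the first token
    "return_text_in_logprobs": True,  # new: detokenize ids so logprobs carry text
    "stream": False,
}
out = requests.post("http://localhost:30000/generate", json=payload).json()
meta = out["meta_info"]
print(meta["prompt_tokens"], meta["completion_tokens"],
      meta["completion_tokens_wo_jump_forward"])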
python/sglang/srt/managers/router/infer_batch.py

@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0

         # For vision input
         self.pixel_values = None
python/sglang/srt/managers/router/model_rpc.py

@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()

@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
                 req.sampling_params.skip_special_tokens
             )
-            # For the length of input_ids, which will be accumulated during jump-forward.
-            # Use the original length of input_ids to calculate the token usage info.
             meta_info = {
-                "prompt_tokens": req.orig_prompt_tokens,
+                "prompt_tokens": req.prompt_tokens,
                 "completion_tokens": len(req.input_ids) + len(req.output_ids)
-                - req.orig_prompt_tokens,
+                - req.prompt_tokens,
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
                 meta_info["token_logprob"] = req.token_logprob
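A quick sanity check of the accounting above, with made-up numbers: assume a 10-token prompt, a jump-forward that appends 4 constraint-driven tokens to input_ids before any normal decoding, and 6 tokens decoded afterwards. The values below are purely illustrative.

# Illustrative-only arithmetic for the meta_info fields (hypothetical numbers).
prompt_tokens = 10                     # len(input_ids) when the Req was created
len_input_ids = 10 + 4                 # prompt retokenized with 4 jump-forward tokens
len_output_ids = 6                     # tokens decoded after the last retokenization
completion_tokens = len_input_ids + len_output_ids - prompt_tokens
assert completion_tokens == 10         # counts jump-forward tokens as output
completion_tokens_wo_jump_forward = 6  # incremented once per decode step only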
python/sglang/srt/server.py

@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+from sglang.srt.utils import handle_port_init

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()

-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)
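The pairing logic in detokenize_logprob_tokens is easy to get wrong, so here is a self-contained, synchronous imitation of it that can be checked in isolation. DummyTokenizer and its vocabulary are invented for illustration; the real code awaits tokenizer_manager.detokenize(DetokenizeReqInput(...)) instead.

# Stand-alone illustration of the (token_id, logprob) -> (token_text, logprob)
# mapping. DummyTokenizer is hypothetical; only the zip/pairing logic matches
# detokenize_logprob_tokens above.
class DummyTokenizer:
    vocab = {3: "The", 7: " capital", 11: " of"}

    def detokenize(self, token_ids):
        return [self.vocab[tid] for tid in token_ids]

def detokenize_logprob_tokens_sync(tokenizer, token_logprobs):
    token_ids = [tid for tid, _ in token_logprobs]
    token_texts = tokenizer.detokenize(token_ids)
    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]

pairs = [(3, -0.01), (7, -1.20), (11, -0.35)]
print(detokenize_logprob_tokens_sync(DummyTokenizer(), pairs))
# [('The', -0.01), (' capital', -1.2), (' of', -0.35)]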
@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
     return ret
@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
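End to end, the OpenAI-compatible route now asks for text-level logprobs (return_text_in_logprobs=True above) and reports the usual usage counters. A hedged example request follows; the port and model name are placeholders, and the response layout follows the standard OpenAI completion schema.

# Hedged end-to-end example (not from the commit): POSTing to the
# OpenAI-compatible /v1/completions route with logprobs enabled. Port 30000
# and the model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 0,
        "logprobs": 1,  # maps to return_logprob=True / return_text_in_logprobs=True above
    },
).json()

print(resp["choices"][0]["logprobs"]["tokens"])          # detokenized token texts
print(resp["choices"][0]["logprobs"]["token_logprobs"])  # matching logprobs
print(resp["usage"])                                     # prompt/completion token counts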