change / sglang / Commits / 63ba630b

Unverified commit 63ba630b, authored Feb 15, 2024 by Cody Yu, committed by GitHub on Feb 15, 2024.
Parent: 6493256b

Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)

Showing 4 changed files with 35 additions and 15 deletions (+35, -15).
Changed files:
  python/sglang/srt/managers/io_struct.py            +3   -1
  python/sglang/srt/managers/router/infer_batch.py   +6   -2
  python/sglang/srt/managers/router/model_rpc.py     +6   -5
  python/sglang/srt/server.py                         +20  -7
python/sglang/srt/managers/io_struct.py

@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
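The two new request fields ride along on the existing /generate payload. Below is a minimal client-side sketch exercising them, assuming an sglang server running locally on the default port; the prompt, sampling values, and URL are illustrative and not part of this commit.

    import requests  # any HTTP client works; requests is used here for brevity

    # Payload keys mirror the GenerateReqInput dataclass fields shown above.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,           # ask the server for logprobs
        "logprob_start_len": 0,           # report logprobs from the start of the prompt
        "return_text_in_logprobs": True,  # new field: detokenize tokens in the logprob list
        "stream": False,
    }

    out = requests.post("http://localhost:30000/generate", json=payload).json()
    # After this commit, meta_info also reports completion_tokens_wo_jump_forward.
    print(out["meta_info"]["completion_tokens"],
          out["meta_info"]["completion_tokens_wo_jump_forward"])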
python/sglang/srt/managers/router/infer_batch.py

@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0
         # For vision input
         self.pixel_values = None
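The split exists because jump-forward decoding can splice constant text into the sequence and re-tokenize the prompt together with partial outputs, so len(input_ids) may drift away from the length of the prompt the user actually sent. A toy sketch of the bookkeeping this commit introduces (not sglang's actual Req class; the class and method names here are hypothetical):

    class ToyReq:
        """Illustrates only the two counters added to Req; everything else is omitted."""

        def __init__(self, input_ids):
            self.input_ids = list(input_ids)
            self.output_ids = []
            self.prompt_tokens = len(input_ids)          # pinned at creation time
            self.completion_tokens_wo_jump_forward = 0   # bumped once per decode step

        def decode_step(self, next_token_id):
            self.completion_tokens_wo_jump_forward += 1
            self.output_ids.append(next_token_id)

        def jump_forward(self, new_input_ids):
            # Jump forward folds prompt + partial output + constant text into a
            # fresh input_ids and clears output_ids; no decode step happens here.
            self.input_ids = list(new_input_ids)
            self.output_ids = []

        def usage(self):
            return {
                "prompt_tokens": self.prompt_tokens,
                "completion_tokens": len(self.input_ids)
                + len(self.output_ids)
                - self.prompt_tokens,
                "completion_tokens_wo_jump_forward": self.completion_tokens_wo_jump_forward,
            }

With the counters pinned this way, prompt_tokens always reflects the original prompt and completion_tokens_wo_jump_forward counts only real decoding steps, no matter how many times jump-forward rewrites input_ids.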
python/sglang/srt/managers/router/model_rpc.py

@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()

@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()

@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
                 req.sampling_params.skip_special_tokens
             )
-            # For the length of input_ids, which will be accumulated during jump-forward.
+            # Use the original length of input_ids to calculate the token usage info.
             meta_info = {
-                "prompt_tokens": req.orig_prompt_tokens,
+                "prompt_tokens": req.prompt_tokens,
                 "completion_tokens": len(req.input_ids)
                 + len(req.output_ids)
-                - req.orig_prompt_tokens,
+                - req.prompt_tokens,
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
             if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
                 meta_info["token_logprob"] = req.token_logprob
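To make the usage arithmetic in the last hunk concrete, here is a worked example with made-up numbers: a 10-token prompt, 3 decoded tokens, one jump-forward that folds those outputs plus 3 constant tokens back into input_ids, then 2 more decoded tokens.

    prompt_tokens = 10
    input_ids_len = 16      # 10 prompt + 3 folded outputs + 3 jump-forward constant tokens
    output_ids_len = 2      # tokens decoded after the jump-forward
    completion_tokens_wo_jump_forward = 5   # 3 + 2 actual decode steps

    meta_info = {
        "prompt_tokens": prompt_tokens,
        # Everything produced after the original prompt, jump-forward text included:
        # 16 + 2 - 10 = 8.
        "completion_tokens": input_ids_len + output_ids_len - prompt_tokens,
        # Only tokens produced by real decode steps: 5.
        "completion_tokens_wo_jump_forward": completion_tokens_wo_jump_forward,
    }
    print(meta_info)

The gap between the two completion counts (8 vs. 5) is exactly the text inserted by jump-forward.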
python/sglang/srt/server.py

@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+from sglang.srt.utils import handle_port_init

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()
-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)

@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
     return ret

@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()

@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
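The net effect of the refactor is that detokenization happens once, in detokenize_logprob_tokens, and downstream code such as make_openai_style_logprobs only sees (token_text, logprob) pairs. A standalone sketch of that shape change, with made-up token ids and logprob values:

    # Before detokenization: (token_id, logprob) pairs, as produced by the router.
    token_logprobs = [(464, -0.11), (6342, -0.03)]   # illustrative values
    token_texts = [" The", " Paris"]                  # what detokenization might return

    detokenized = [
        (text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)
    ]
    # -> [(" The", -0.11), (" Paris", -0.03)]

    tokens = [text for text, _ in detokenized]
    logprobs = [lp for _, lp in detokenized]
    print(tokens, logprobs)

Because the /v1/completions adapter now always passes return_text_in_logprobs=True, make_openai_style_logprobs no longer needs to call the tokenizer at all.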