"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "988369a01c4bb910a99cde46baa9e2b5b0b69aab"
Unverified Commit 63ba630b authored by Cody Yu, committed by GitHub

Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)

parent 6493256b
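
For API users, the net effect of this commit: `/generate` can return per-token logprobs as (token_text, logprob) pairs instead of (token_id, logprob) pairs, and the usage report now distinguishes sampled tokens from jump-forward tokens. A minimal client sketch, assuming a locally running SRT server on the usual default port; the prompt and sampling params are illustrative:

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",   # assumed local server address
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        "return_logprob": True,
        "return_text_in_logprobs": True,  # new flag added in this commit
    },
).json()

meta = resp["meta_info"]
print(meta["prompt_tokens"], meta["completion_tokens"],
      meta["completion_tokens_wo_jump_forward"])
print(meta["token_logprob"][:3])  # (token_text, logprob) pairs when the flag is set
```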
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
...
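
For callers that build the request object in-process, a sketch of the dataclass usage; the import path and the `text` field are assumptions from surrounding code, and the new flag defaults to False, so existing callers are unaffected:

```python
from sglang.srt.managers.io_struct import GenerateReqInput  # import path assumed

req = GenerateReqInput(
    text="The capital of France is",   # prompt field assumed from context
    sampling_params={"max_new_tokens": 16},
    return_logprob=True,
    return_text_in_logprobs=True,      # new: ask the server to detokenize logprob tokens
)
req.post_init()
```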
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0

         # For vision input
         self.pixel_values = None
...
@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()
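
Both the prefill/extend path (first hunk) and the decode path (second hunk) bump the counter exactly once per forward pass that samples a token; a jump-forward step extends `input_ids` without sampling and therefore leaves it untouched. A toy illustration of the bookkeeping, with made-up token ids and a stripped-down stand-in for `Req`:

```python
class ToyReq:
    """Stand-in for Req, keeping only the fields needed to show the counters."""

    def __init__(self, input_ids):
        self.input_ids = list(input_ids)
        self.prompt_tokens = len(input_ids)          # frozen at request creation
        self.output_ids = []
        self.completion_tokens_wo_jump_forward = 0   # counts sampled tokens only


req = ToyReq([101, 102, 103])

# Prefill/extend step: the first output token is sampled.
req.output_ids = [201]
req.completion_tokens_wo_jump_forward += 1

# Decode step: one sampled token per iteration.
req.output_ids.append(202)
req.completion_tokens_wo_jump_forward += 1

# Jump-forward step (handled elsewhere, not in this diff): constrained text is
# re-tokenized into input_ids without sampling, so the counter does not move.
req.input_ids += req.output_ids + [203, 204]
req.output_ids = []

assert req.completion_tokens_wo_jump_forward == 2
```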
@@ -541,15 +543,14 @@ class ModelRpcServer(rpyc.Service):
                     req.sampling_params.skip_special_tokens
                 )
-                # For the length of input_ids, which will be accumulated during jump-forward.
-                # Use the original length of input_ids to calculate the token usage info.
                 meta_info = {
-                    "prompt_tokens": req.orig_prompt_tokens,
+                    "prompt_tokens": req.prompt_tokens,
                     "completion_tokens": len(req.input_ids)
                     + len(req.output_ids)
-                    - req.orig_prompt_tokens,
+                    - req.prompt_tokens,
+                    "completion_tokens_wo_jump_forward":
+                        req.completion_tokens_wo_jump_forward
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
                     meta_info["token_logprob"] = req.token_logprob
...
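
Continuing the toy numbers from the sketch above (plus one more decode step after the jump-forward): the usage report derives `completion_tokens` from the current, re-tokenized `input_ids`, so it includes jump-forward tokens, while the new field counts only sampled tokens:

```python
prompt_tokens = 3                       # original prompt length, frozen at creation
len_input_ids = 7                       # grew after the jump-forward re-tokenization
len_output_ids = 1                      # sampled since the last re-tokenization
completion_tokens_wo_jump_forward = 3   # every token the model actually sampled

meta_info = {
    "prompt_tokens": prompt_tokens,
    # 7 + 1 - 3 = 5: includes the two tokens inserted by jump forward.
    "completion_tokens": len_input_ids + len_output_ids - prompt_tokens,
    "completion_tokens_wo_jump_forward": completion_tokens_wo_jump_forward,
}
print(meta_info)
```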
@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+from sglang.srt.utils import handle_port_init

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()
-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)
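
The shape change in this hunk: `token_logprob` starts out as (token_id, logprob) pairs, `detokenize_logprob_tokens` turns them into (token_text, logprob) pairs, and `make_openai_style_logprobs` now consumes the already-detokenized pairs. A self-contained sketch with a stub detokenizer standing in for the async, batched `tokenizer_manager.detokenize`:

```python
# (token_id, logprob) pairs as they leave the model server; ids and values are made up.
token_logprobs = [(791, -0.12), (6864, -0.43), (315, -0.05)]


def stub_detokenize(token_ids):
    vocab = {791: "The", 6864: " capital", 315: " of"}  # toy mapping
    return [vocab[t] for t in token_ids]


# Same two comprehensions as detokenize_logprob_tokens, minus the async call.
token_ids = [tid for tid, _ in token_logprobs]
token_texts = stub_detokenize(token_ids)
detokenized = [
    (text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)
]
print(detokenized)  # [('The', -0.12), (' capital', -0.43), (' of', -0.05)]
```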
@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
     return ret
@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -211,6 +223,7 @@ async def v1_completions(raw_request: Request):
     # Non-streaming response.
     ret = await generate_request(adapted_request)
+    ret = ret[0] if isinstance(ret, list) else ret
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
...
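
On the OpenAI-compatible route, `return_text_in_logprobs=True` is now always set, so the `logprobs` block in a completion response carries token text directly, and the new list-unwrapping line guards against `generate_request` returning a batch. A hedged client sketch; the endpoint, model name, and prompt are placeholders:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",   # assumed local server address
    json={
        "model": "default",                    # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 0,
        "logprobs": 1,                         # any non-None value enables return_logprob
    },
).json()

logprobs = resp["choices"][0]["logprobs"]
print(logprobs["tokens"])          # detokenized token strings
print(logprobs["token_logprobs"])  # matching per-token logprobs
print(resp["usage"])               # usage per the OpenAI completion schema
```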