Unverified Commit 63ba630b authored by Cody Yu, committed by GitHub

Refactor decoding logprob and add completion_tokens_wo_jump_forward (#189)

parent 6493256b
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
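
For context, a minimal client sketch of how these fields might be exercised against a running server. The host, port, prompt, and sampling parameters below are assumptions for illustration, not part of this diff:

```python
import requests

# Hypothetical local endpoint; the actual host/port depend on how the
# sglang server was launched.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        # Ask the server to return per-token logprobs...
        "return_logprob": True,
        # ...and to detokenize them so each entry carries the token text
        # instead of only the token id (the new field in this diff).
        "return_text_in_logprobs": True,
    },
)
print(response.json()["meta_info"]["token_logprob"])
```
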
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0
         # For vision input
         self.pixel_values = None
@@ -424,6 +424,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
@@ -500,6 +501,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()
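
Taken together with the new Req fields, the per-step increments above can be pictured with a small self-contained sketch. This is an illustrative stand-in, not the real Req/ModelRpcServer code, and the jump-forward handling is deliberately simplified:

```python
class UsageCounter:
    """Illustrative stand-in for Req's token accounting (not the real class)."""

    def __init__(self, input_ids):
        self.input_ids = list(input_ids)
        self.output_ids = []
        # Frozen at the original prompt length, even if jump-forward
        # later retokenizes and grows input_ids.
        self.prompt_tokens = len(input_ids)
        # Bumped once per real decode step, mirroring the `+= 1`
        # added in the two loops above.
        self.completion_tokens_wo_jump_forward = 0

    def decode_step(self, next_token_id):
        self.completion_tokens_wo_jump_forward += 1
        self.output_ids.append(next_token_id)

    def jump_forward(self, retokenized_input_ids):
        # Simplified: fold prompt + partial output + jumped text back into
        # input_ids; the decode counter is deliberately left untouched.
        self.input_ids = list(retokenized_input_ids)
        self.output_ids = []
```
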
@@ -541,15 +543,14 @@
                     req.sampling_params.skip_special_tokens
                 )
-                # For the length of input_ids, which will be accumulated during jump-forward.
-                # Use the original length of input_ids to calculate the token usage info.
                 meta_info = {
-                    "prompt_tokens": req.orig_prompt_tokens,
+                    "prompt_tokens": req.prompt_tokens,
                     "completion_tokens": len(req.input_ids)
                     + len(req.output_ids)
-                    - req.orig_prompt_tokens,
+                    - req.prompt_tokens,
+                    "completion_tokens_wo_jump_forward":
+                        req.completion_tokens_wo_jump_forward,
                 }
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
                     meta_info["token_logprob"] = req.token_logprob
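
A worked example of the reported numbers, assuming (for illustration) a 10-token prompt, a jump-forward that happens right after prefill and adds 4 tokens to input_ids, and 6 subsequently decoded tokens:

```python
# Numbers chosen for illustration only.
prompt_tokens = 10                       # len(input_ids) at request creation
len_input_ids_after_jump_forward = 14    # prompt + 4 jump-forward tokens, retokenized
len_output_ids = 6                       # tokens sampled after the jump-forward
completion_tokens_wo_jump_forward = 6    # incremented once per decode step

completion_tokens = (
    len_input_ids_after_jump_forward + len_output_ids - prompt_tokens
)
assert completion_tokens == 10                 # 6 decoded + 4 jump-forward tokens
assert completion_tokens_wo_jump_forward == 6  # decoded tokens only
```
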
......@@ -52,7 +52,7 @@ from sglang.srt.managers.openai_protocol import (
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import alloc_usable_network_port, handle_port_init
from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -96,19 +96,25 @@ async def flush_cache():
     )


-async def stream_generator(obj):
+async def detokenize_logprob_tokens(token_logprobs):
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+    return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+async def stream_generator(obj: GenerateReqInput):
     async for out in tokenizer_manager.generate_request(obj):
+        if obj.return_logprob and obj.return_text_in_logprobs:
+            out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                out["meta_info"]["token_logprob"]
+            )
         yield out


 async def make_openai_style_logprobs(token_logprobs):
     ret_logprobs = LogProbs()

-    # Detokenize
-    token_ids = [tid for tid, _ in token_logprobs]
-    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+    for token_text, token_logprob in token_logprobs:
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(token_logprob)
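
The shape of the data being transformed, sketched with a stand-in detokenizer. The real path is async and goes through tokenizer_manager.detokenize; the token ids and texts below are made up:

```python
# token_logprob as produced by the server: (token_id, logprob) pairs.
token_logprobs = [(450, -0.12), (7483, -0.03), (310, -0.45)]


def fake_detokenize(token_ids):
    # Stand-in for the tokenizer manager; maps ids to token texts.
    vocab = {450: "The", 7483: " capital", 310: " of"}
    return [vocab[tid] for tid in token_ids]


token_ids = [tid for tid, _ in token_logprobs]
token_texts = fake_detokenize(token_ids)

# Same pairing as detokenize_logprob_tokens: keep the logprob,
# swap the token id for its text.
with_text = [
    (text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)
]
assert with_text == [("The", -0.12), (" capital", -0.03), (" of", -0.45)]
```
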
@@ -132,6 +138,11 @@ async def generate_request(obj: GenerateReqInput):
         return StreamingResponse(stream_results(), media_type="text/event-stream")

     ret = await tokenizer_manager.generate_request(obj).__anext__()
+    if obj.return_logprob and obj.return_text_in_logprobs:
+        ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+            ret["meta_info"]["token_logprob"]
+        )
     return ret
@@ -155,6 +166,7 @@ async def v1_completions(raw_request: Request):
             "regex": request.regex,
         },
         return_logprob=request.logprobs is not None,
+        return_text_in_logprobs=True,
         stream=request.stream,
     )
     adapted_request.post_init()
@@ -211,6 +223,7 @@
     # Non-streaming response.
     ret = await generate_request(adapted_request)
     ret = ret[0] if isinstance(ret, list) else ret
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
     completion_tokens = ret["meta_info"]["completion_tokens"]
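
With return_text_in_logprobs forced on for the OpenAI-compatible route, a logprobs request can be served without an extra detokenization pass in this handler. A hedged client sketch; the host, port, model name, and prompt are assumptions:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",  # assumed default host/port
    json={
        "model": "default",  # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 0,
        # Any non-None value turns on return_logprob for the adapted request.
        "logprobs": 1,
    },
)
body = resp.json()
print(body["choices"][0]["logprobs"]["tokens"])          # detokenized token texts
print(body["choices"][0]["logprobs"]["token_logprobs"])  # per-token logprobs
print(body["usage"])  # prompt_tokens / completion_tokens from meta_info
```
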