Unverified commit 35bdb485, authored by Shi Shuai, committed by GitHub

[Feature] Get Token IDs with Engine.generate() (#2636)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent b085e06b
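
For context, here is a minimal usage sketch of the feature this commit adds, distilled from the new test file at the end of this diff; the model path and sampling parameters are illustrative, not prescribed:

    import sglang as sgl

    # Enable the new flag so generate() returns raw token ids alongside text.
    llm = sgl.Engine(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        return_token_ids=True,
    )
    outputs = llm.generate(["The capital of France is"], {"temperature": 0.8})
    for out in outputs:
        print(out["text"])        # decoded completion
        print(out["input_ids"])   # prompt token ids (origin_input_ids)
        print(out["output_ids"])  # generated token ids
    llm.shutdown()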
@@ -181,6 +181,8 @@ class DetokenizerManager:
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
                 prompt_tokens=recv_obj.prompt_tokens,
+                origin_input_ids=recv_obj.origin_input_ids,
+                output_ids=recv_obj.output_ids,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
...
@@ -323,7 +323,9 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when `--skip-tokenizer-init`
+    # Only used when `--return-token-ids` is set
+    origin_input_ids: Optional[List[int]]
+    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -354,6 +356,10 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
+    # The token ids
+    origin_input_ids: Optional[List[int]]
+    output_ids: Optional[List[int]]
     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
...
@@ -1218,6 +1218,7 @@ class Scheduler:
         decode_ids_list = []
         read_offsets = []
         output_ids = []
+        origin_input_ids = []
         skip_special_tokens = []
         spaces_between_special_tokens = []
@@ -1266,8 +1267,14 @@
             decode_ids, read_offset = req.init_incremental_detokenize()
             decode_ids_list.append(decode_ids)
             read_offsets.append(read_offset)
-            if self.skip_tokenizer_init:
+            if self.skip_tokenizer_init or self.server_args.return_token_ids:
                 output_ids.append(req.output_ids)
+            else:
+                output_ids = None
+            if self.server_args.return_token_ids:
+                origin_input_ids.append(req.origin_input_ids)
+            else:
+                origin_input_ids = None
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
@@ -1299,6 +1306,7 @@ class Scheduler:
                 decoded_texts,
                 decode_ids_list,
                 read_offsets,
+                origin_input_ids,
                 output_ids,
                 skip_special_tokens,
                 spaces_between_special_tokens,
...
@@ -663,6 +663,13 @@ class TokenizerManager:
                 "text": recv_obj.output_strs[i],
                 "meta_info": meta_info,
             }
+            if self.server_args.return_token_ids:
+                out_dict.update(
+                    {
+                        "input_ids": recv_obj.origin_input_ids[i],
+                        "output_ids": recv_obj.output_ids[i],
+                    }
+                )
         elif isinstance(recv_obj, BatchTokenIDOut):
             out_dict = {
                 "token_ids": recv_obj.output_ids[i],
...
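
For illustration, the per-request dict assembled above would then look roughly like this when --return-token-ids is set; the ids and counts below are made up, and meta_info is abbreviated:

    # Hypothetical result shape, not actual output.
    out_dict = {
        "text": " Paris.",
        "meta_info": {"prompt_tokens": 6, "completion_tokens": 2},  # abbreviated
        "input_ids": [128000, 791, 6864, 315, 9822, 374],  # made-up prompt token ids
        "output_ids": [12366, 13],                          # made-up generated token ids
    }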
@@ -54,6 +54,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
+    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -280,6 +281,12 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request",
         )
+        parser.add_argument(
+            "--return-token-ids",
+            action="store_true",
+            default=ServerArgs.return_token_ids,
+            help="Whether to return token IDs in the output; this may introduce additional overhead.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
...
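
The same flag can also be passed when launching a standalone server; a sketch assuming the usual sglang.launch_server entry point, which is not part of this diff:

    python -m sglang.launch_server \
        --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
        --return-token-ids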
@@ -44,6 +44,7 @@ suites = {
         "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
+        "test_engine_token_ids.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
...
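
Because the new test module below calls unittest.main() itself, it can also be run directly while iterating, e.g.:

    python test_engine_token_ids.py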
+import unittest
+
+from transformers import AutoTokenizer
+
+import sglang as sgl
+
+
+class TestEngineTokenIds(unittest.TestCase):
+    def test_token_ids_in_generate(self):
+        llm = sgl.Engine(
+            model_path="meta-llama/Meta-Llama-3.1-8B-Instruct", return_token_ids=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "meta-llama/Meta-Llama-3.1-8B-Instruct"
+        )
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        sampling_params = {"temperature": 0.8, "top_p": 0.95}
+        outputs = llm.generate(prompts, sampling_params)
+
+        # The Hugging Face tokenizer prepends a start (BOS) token to its output,
+        # while SGLang only collects generated next-token ids in output_ids.
+        # We therefore strip the start token from the HF encoding before comparing.
+        for prompt, output in zip(prompts, outputs):
+            hf_input_ids = tokenizer.encode(prompt)
+            self.assertEqual(
+                output["input_ids"],
+                hf_input_ids,
+                f"Input token IDs mismatch for: {prompt}",
+            )
+            hf_output_ids = tokenizer.encode(output["text"])[1:]  # remove start token
+            self.assertEqual(
+                output["output_ids"],
+                hf_output_ids,
+                f"Output token IDs mismatch for: {output['text']}",
+            )
+            self.assertEqual(
+                len(output["input_ids"]),
+                output["meta_info"]["prompt_tokens"],
+                "Prompt token count mismatch",
+            )
+            self.assertEqual(
+                len(output["output_ids"]),
+                output["meta_info"]["completion_tokens"],
+                "Completion token count mismatch",
+            )
+        llm.shutdown()
+
+
+if __name__ == "__main__":
+    unittest.main()
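
As a side note on the start-token comment in the test above, a quick check of the BOS behavior it works around, sketched under the assumption that the Llama-3.1 tokenizer prepends a BOS id on encode by default:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
    ids = tok.encode("The capital of France is")
    # encode() prepends the BOS token; SGLang's output_ids never include it,
    # which is why the test slices off the first HF token before comparing.
    assert ids[0] == tok.bos_token_id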