Unverified Commit 748fee6b authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

feat(sglang): enforce stream_output=True for optimal streaming performance (#5510)



This ensures that only new tokens are returned by sglang which avoids the overhead from creating copies of the entire token sequences per each iteration. These copies can become a bottleneck particularly for long sequence lengths and large concurrency counts.
Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent 7b639e76
......@@ -491,6 +491,12 @@ async def parse_args(args: list[str]) -> Config:
# contain code to download a model, it should only parse the args.
server_args = ServerArgs.from_cli_args(parsed_args)
# Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new
# tokens since last output), not cumulative tokens. When stream_output=True,
# SGLang sends disjoint segments which Dynamo passes through directly.
# Force stream_output=True for optimal streaming performance.
server_args.stream_output = True
if parsed_args.use_sglang_tokenizer:
logging.info(
"Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
......
......@@ -183,6 +183,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process token-based stream output.
With stream_output=True (enforced by Dynamo), SGLang sends disjoint segments
containing only new tokens since the last output. We pass these through directly.
Args:
stream_source: Async generator from engine.async_generate.
context: Context object for cancellation handling.
......@@ -190,8 +193,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Yields:
Dict with token_ids and optional finish_reason.
"""
num_output_tokens_so_far = 0
# Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context):
......@@ -213,6 +214,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if finish_reason:
out["finish_reason"] = finish_reason["type"]
# With stream_output=True, output_ids contains only new tokens (disjoint)
output_ids = res.get("output_ids", [])
# If request is not finished yet, but there are no outputs, return an error.
if not output_ids and not finish_reason:
......@@ -220,9 +222,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
yield {"finish_reason": "error", "token_ids": []}
break
next_total_toks = len(output_ids)
out["token_ids"] = output_ids[num_output_tokens_so_far:]
num_output_tokens_so_far = next_total_toks
# Pass through disjoint token segments directly
out["token_ids"] = output_ids
if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"]
......
......@@ -133,30 +133,26 @@ class StreamProcessor:
@staticmethod
async def process_sglang_stream(stream_source) -> AsyncIterator[str]:
"""Process SGLang stream output following backend pattern"""
num_output_tokens_so_far = 0
"""Process SGLang stream output.
With stream_output=True (enforced by Dynamo), SGLang sends disjoint segments
containing only new tokens since the last output. We pass these through directly.
"""
try:
async for res in stream_source:
try:
next_total_toks = len(res["output_ids"])
# Return incremental tokens
# With stream_output=True, output_ids contains only new tokens (disjoint)
output = {
"token_ids": res["output_ids"][num_output_tokens_so_far:],
"token_ids": res["output_ids"],
"text": res.get("text", ""),
"finished": False,
}
num_output_tokens_so_far = next_total_toks
# Check for finish reason
finish_reason = res.get("meta_info", {}).get("finish_reason")
if finish_reason:
output.update(
{
"token_ids": res["output_ids"][
num_output_tokens_so_far:
],
"finish_reason": finish_reason.get("type", "stop"),
"finished": True,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment