Unverified Commit 748fee6b authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

feat(sglang): enforce stream_output=True for optimal streaming performance (#5510)



This ensures that only new tokens are returned by sglang which avoids the overhead from creating copies of the entire token sequences per each iteration. These copies can become a bottleneck particularly for long sequence lengths and large concurrency counts.
Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent 7b639e76
...@@ -491,6 +491,12 @@ async def parse_args(args: list[str]) -> Config: ...@@ -491,6 +491,12 @@ async def parse_args(args: list[str]) -> Config:
# contain code to download a model, it should only parse the args. # contain code to download a model, it should only parse the args.
server_args = ServerArgs.from_cli_args(parsed_args) server_args = ServerArgs.from_cli_args(parsed_args)
# Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new
# tokens since last output), not cumulative tokens. When stream_output=True,
# SGLang sends disjoint segments which Dynamo passes through directly.
# Force stream_output=True for optimal streaming performance.
server_args.stream_output = True
if parsed_args.use_sglang_tokenizer: if parsed_args.use_sglang_tokenizer:
logging.info( logging.info(
"Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False" "Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
......
...@@ -183,6 +183,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -183,6 +183,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
"""Process token-based stream output. """Process token-based stream output.
With stream_output=True (enforced by Dynamo), SGLang sends disjoint segments
containing only new tokens since the last output. We pass these through directly.
Args: Args:
stream_source: Async generator from engine.async_generate. stream_source: Async generator from engine.async_generate.
context: Context object for cancellation handling. context: Context object for cancellation handling.
...@@ -190,8 +193,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -190,8 +193,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Yields: Yields:
Dict with token_ids and optional finish_reason. Dict with token_ids and optional finish_reason.
""" """
num_output_tokens_so_far = 0
# Use Future pattern for request ID - will be set when first response arrives # Use Future pattern for request ID - will be set when first response arrives
request_id_future = asyncio.Future() request_id_future = asyncio.Future()
async with self._cancellation_monitor(request_id_future, context): async with self._cancellation_monitor(request_id_future, context):
...@@ -213,6 +214,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -213,6 +214,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if finish_reason: if finish_reason:
out["finish_reason"] = finish_reason["type"] out["finish_reason"] = finish_reason["type"]
# With stream_output=True, output_ids contains only new tokens (disjoint)
output_ids = res.get("output_ids", []) output_ids = res.get("output_ids", [])
# If request is not finished yet, but there are no outputs, return an error. # If request is not finished yet, but there are no outputs, return an error.
if not output_ids and not finish_reason: if not output_ids and not finish_reason:
...@@ -220,9 +222,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -220,9 +222,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
yield {"finish_reason": "error", "token_ids": []} yield {"finish_reason": "error", "token_ids": []}
break break
next_total_toks = len(output_ids) # Pass through disjoint token segments directly
out["token_ids"] = output_ids[num_output_tokens_so_far:] out["token_ids"] = output_ids
num_output_tokens_so_far = next_total_toks
if finish_reason: if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"] input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"] completion_tokens = res["meta_info"]["completion_tokens"]
......
...@@ -133,30 +133,26 @@ class StreamProcessor: ...@@ -133,30 +133,26 @@ class StreamProcessor:
@staticmethod @staticmethod
async def process_sglang_stream(stream_source) -> AsyncIterator[str]: async def process_sglang_stream(stream_source) -> AsyncIterator[str]:
"""Process SGLang stream output following backend pattern""" """Process SGLang stream output.
num_output_tokens_so_far = 0
With stream_output=True (enforced by Dynamo), SGLang sends disjoint segments
containing only new tokens since the last output. We pass these through directly.
"""
try: try:
async for res in stream_source: async for res in stream_source:
try: try:
next_total_toks = len(res["output_ids"]) # With stream_output=True, output_ids contains only new tokens (disjoint)
# Return incremental tokens
output = { output = {
"token_ids": res["output_ids"][num_output_tokens_so_far:], "token_ids": res["output_ids"],
"text": res.get("text", ""), "text": res.get("text", ""),
"finished": False, "finished": False,
} }
num_output_tokens_so_far = next_total_toks
# Check for finish reason # Check for finish reason
finish_reason = res.get("meta_info", {}).get("finish_reason") finish_reason = res.get("meta_info", {}).get("finish_reason")
if finish_reason: if finish_reason:
output.update( output.update(
{ {
"token_ids": res["output_ids"][
num_output_tokens_so_far:
],
"finish_reason": finish_reason.get("type", "stop"), "finish_reason": finish_reason.get("type", "stop"),
"finished": True, "finished": True,
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment