Unverified Commit b2cafb7f authored by weireweire's avatar weireweire Committed by GitHub
Browse files

fix(sglang): use incremental streaming output for completions (#7752)


Signed-off-by: default avatarWeiliangl User <weiliangl@login-node.hosted.internal>
Co-authored-by: default avatarWeiliangl User <weiliangl@login-node.hosted.internal>
parent 0ba80f6a
...@@ -18,6 +18,7 @@ fallback and any associated polyfills. ...@@ -18,6 +18,7 @@ fallback and any associated polyfills.
import ipaddress import ipaddress
import logging import logging
import socket import socket
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -98,8 +99,42 @@ except ImportError: ...@@ -98,8 +99,42 @@ except ImportError:
return f"tcp://{self.host}:{self.port}" return f"tcp://{self.host}:{self.port}"
def enable_disjoint_streaming_output(server_args: Any) -> None:
"""
Enable SGLang's disjoint streaming output across ServerArgs field renames.
Covers sglang <= 0.5.x (`stream_output`) and newer releases
(`incremental_streaming_output`).
"""
fields = getattr(type(server_args), "__dataclass_fields__", None)
if isinstance(fields, dict):
if "incremental_streaming_output" in fields:
server_args.incremental_streaming_output = True
return
if "stream_output" in fields:
server_args.stream_output = True
return
raise AttributeError(
"SGLang ServerArgs has neither 'incremental_streaming_output' nor "
"'stream_output'"
)
if hasattr(server_args, "incremental_streaming_output"):
server_args.incremental_streaming_output = True
return
if hasattr(server_args, "stream_output"):
server_args.stream_output = True
return
logger.debug(
"Skipping streaming output compatibility for non-ServerArgs object: %s",
type(server_args).__name__,
)
__all__ = [ __all__ = [
"NetworkAddress", "NetworkAddress",
"enable_disjoint_streaming_output",
"get_local_ip_auto", "get_local_ip_auto",
"get_zmq_socket", "get_zmq_socket",
"_SGLANG_HAS_NETWORK_MODULE", "_SGLANG_HAS_NETWORK_MODULE",
......
...@@ -24,6 +24,7 @@ from dynamo.common.constants import DisaggregationMode ...@@ -24,6 +24,7 @@ from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.runtime import parse_endpoint from dynamo.common.utils.runtime import parse_endpoint
from dynamo.llm import fetch_model from dynamo.llm import fetch_model
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang._compat import enable_disjoint_streaming_output
from dynamo.sglang.backend_args import DynamoSGLangArgGroup, DynamoSGLangConfig from dynamo.sglang.backend_args import DynamoSGLangArgGroup, DynamoSGLangConfig
configure_dynamo_logging() configure_dynamo_logging()
...@@ -374,12 +375,10 @@ async def parse_args(args: list[str]) -> Config: ...@@ -374,12 +375,10 @@ async def parse_args(args: list[str]) -> Config:
) )
# Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new # Dynamo's streaming handlers expect disjoint output_ids from SGLang (only new
# tokens since last output), not cumulative tokens. # tokens since last output), not cumulative tokens. Modern SGLang gates this
# sglang renamed stream_output -> incremental_streaming_output in PR #20614. # behavior behind incremental_streaming_output, while older releases used
if hasattr(ServerArgs, "incremental_streaming_output"): # stream_output.
server_args.incremental_streaming_output = True enable_disjoint_streaming_output(server_args)
else:
server_args.stream_output = True
if dynamo_config.use_sglang_tokenizer: if dynamo_config.use_sglang_tokenizer:
warnings.warn( warnings.warn(
......
...@@ -284,6 +284,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -284,6 +284,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
// Update prompt_tokens from worker if provided (e.g., for embeddings) // Update prompt_tokens from worker if provided (e.g., for embeddings)
self.usage.prompt_tokens = completion_usage.prompt_tokens; self.usage.prompt_tokens = completion_usage.prompt_tokens;
// Propagate completion token details if provided
if let Some(completion_details) = completion_usage.completion_tokens_details.as_ref() {
self.usage.completion_tokens_details = Some(completion_details.clone());
}
// Propagate prompt token details if provided // Propagate prompt token details if provided
if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() { if let Some(prompt_details) = completion_usage.prompt_tokens_details.as_ref() {
self.usage.prompt_tokens_details = Some(prompt_details.clone()); self.usage.prompt_tokens_details = Some(prompt_details.clone());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment