Unverified commit cf79c4fc, authored by Oleg Zhelezniak, committed by GitHub
Browse files

feat: sglang guided decoding support (#6620)


Signed-off-by: jellysnack <oleg.jellysnack@gmail.com>
parent 8409e412
...@@ -353,7 +353,6 @@ async def parse_args(args: list[str]) -> Config: ...@@ -353,7 +353,6 @@ async def parse_args(args: list[str]) -> Config:
server_args.served_model_name = parsed_args.served_model_name server_args.served_model_name = parsed_args.served_model_name
server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False) server_args.enable_metrics = getattr(parsed_args, "enable_metrics", False)
server_args.log_level = getattr(parsed_args, "log_level", "info") server_args.log_level = getattr(parsed_args, "log_level", "info")
server_args.skip_tokenizer_init = True
server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None) server_args.kv_events_config = getattr(parsed_args, "kv_events_config", None)
server_args.tp_size = getattr(parsed_args, "tp_size", 1) server_args.tp_size = getattr(parsed_args, "tp_size", 1)
server_args.dp_size = getattr(parsed_args, "dp_size", 1) server_args.dp_size = getattr(parsed_args, "dp_size", 1)
...@@ -389,15 +388,9 @@ async def parse_args(args: list[str]) -> Config: ...@@ -389,15 +388,9 @@ async def parse_args(args: list[str]) -> Config:
FutureWarning, FutureWarning,
stacklevel=2, stacklevel=2,
) )
logging.info( logging.info("Using SGLang's built in tokenizer")
"Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
)
server_args.skip_tokenizer_init = False
else: else:
logging.info( logging.info("Using dynamo's built in tokenizer")
"Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
)
server_args.skip_tokenizer_init = True
# Derive use_kv_events from server_args.kv_events_config # Derive use_kv_events from server_args.kv_events_config
# Check that kv_events_config exists AND publisher is not "null" ("zmq" or any future publishers) # Check that kv_events_config exists AND publisher is not "null" ("zmq" or any future publishers)
......
...@@ -38,9 +38,9 @@ async def _register_model_with_runtime_config( ...@@ -38,9 +38,9 @@ async def _register_model_with_runtime_config(
""" """
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args) runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
if not server_args.skip_tokenizer_init: if dynamo_args.use_sglang_tokenizer:
logging.warning( logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available" "Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
) )
input_type = ModelInput.Text input_type = ModelInput.Text
# Only override output_type for chat models, not for embeddings # Only override output_type for chat models, not for embeddings
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import inspect import inspect
import json
import logging import logging
import random import random
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -173,13 +174,13 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]): ...@@ -173,13 +174,13 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
self.metrics_publisher = publisher.metrics_publisher self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher self.kv_publisher = publisher.kv_publisher
self.serving_mode = config.serving_mode self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init self.use_sglang_tokenizer = config.dynamo_args.use_sglang_tokenizer
self.enable_trace = config.server_args.enable_trace self.enable_trace = config.server_args.enable_trace
if engine is not None: if engine is not None:
self.input_param_manager = InputParamManager( self.input_param_manager = InputParamManager(
self.engine.tokenizer_manager.tokenizer self.engine.tokenizer_manager.tokenizer
if not self.skip_tokenizer_init if self.use_sglang_tokenizer
else None else None
) )
self._engine_supports_priority = ( self._engine_supports_priority = (
...@@ -430,13 +431,24 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]): ...@@ -430,13 +431,24 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]: def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
request_input = self.input_param_manager.get_input_param( request_input = self.input_param_manager.get_input_param(
request, use_tokenizer=not self.skip_tokenizer_init request, use_tokenizer=self.use_sglang_tokenizer
) )
return { return {
"prompt" if isinstance(request_input, str) else "input_ids": request_input "prompt" if isinstance(request_input, str) else "input_ids": request_input
} }
@staticmethod
def _get_guided_decoding_params(
guided_decoding: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""Extract guided decoding params (e.g. json_schema) for SGLang sampling_params."""
if isinstance(guided_decoding, dict):
json_schema = guided_decoding.get("json")
if json_schema is not None:
return {"json_schema": json.dumps(json_schema)}
return {}
@staticmethod @staticmethod
def _generate_bootstrap_room() -> int: def _generate_bootstrap_room() -> int:
"""Generate a unique bootstrap room ID for disaggregated serving. """Generate a unique bootstrap room ID for disaggregated serving.
......
...@@ -88,7 +88,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -88,7 +88,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
Returns: Returns:
Dict of sampling parameters for SGLang engine. Dict of sampling parameters for SGLang engine.
""" """
if self.skip_tokenizer_init: if not self.use_sglang_tokenizer:
# Token-based request format # Token-based request format
sampling_opts = request.get("sampling_options", {}) sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {}) stop_conditions = request.get("stop_conditions", {})
...@@ -99,6 +99,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -99,6 +99,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"top_k": sampling_opts.get("top_k"), "top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"), "max_new_tokens": stop_conditions.get("max_tokens"),
"ignore_eos": stop_conditions.get("ignore_eos"), "ignore_eos": stop_conditions.get("ignore_eos"),
**self._get_guided_decoding_params(
sampling_opts.get("guided_decoding")
),
} }
else: else:
# OpenAI request format # OpenAI request format
...@@ -107,6 +110,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -107,6 +110,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"top_p": request.get("top_p"), "top_p": request.get("top_p"),
"top_k": request.get("top_k"), "top_k": request.get("top_k"),
"max_new_tokens": request.get("max_tokens"), "max_new_tokens": request.get("max_tokens"),
**self._get_guided_decoding_params(request.get("guided_decoding")),
} }
return {k: v for k, v in param_mapping.items() if v is not None} return {k: v for k, v in param_mapping.items() if v is not None}
...@@ -305,7 +309,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -305,7 +309,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
**self._priority_kwargs(priority), **self._priority_kwargs(priority),
) )
if self.skip_tokenizer_init: if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(decode, context): async for out in self._process_token_stream(decode, context):
yield out yield out
else: else:
...@@ -337,7 +341,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -337,7 +341,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
**logprob_kwargs, **logprob_kwargs,
**self._priority_kwargs(priority), **self._priority_kwargs(priority),
) )
if self.skip_tokenizer_init: if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(agg, context): async for out in self._process_token_stream(agg, context):
yield out yield out
else: else:
......
...@@ -89,7 +89,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler): ...@@ -89,7 +89,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
) )
# Process stream output (token-based or text-based) # Process stream output (token-based or text-based)
if self.skip_tokenizer_init: if not self.use_sglang_tokenizer:
async for out in self._process_token_stream(async_gen, context): async for out in self._process_token_stream(async_gen, context):
yield out yield out
else: else:
......
...@@ -86,6 +86,9 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -86,6 +86,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"top_p": sampling_opts.get("top_p"), "top_p": sampling_opts.get("top_p"),
"top_k": sampling_opts.get("top_k"), "top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"), "max_new_tokens": stop_conditions.get("max_tokens"),
**self._get_guided_decoding_params(
sampling_opts.get("guided_decoding")
),
} }
sampling_params = { sampling_params = {
k: v for k, v in sampling_params.items() if v is not None k: v for k, v in sampling_params.items() if v is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment