Unverified Commit f9be2e9e authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: allow framework tokenization/detokenization (#3134)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-44-124.ec2.internal>
parent 6243bcbe
...@@ -23,6 +23,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -23,6 +23,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
## Table of Contents ## Table of Contents
- [Feature Support Matrix](#feature-support-matrix) - [Feature Support Matrix](#feature-support-matrix)
- [Dynamo SGLang Integration](#dynamo-sglang-integration)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Single Node Examples](#run-single-node-examples) - [Single Node Examples](#run-single-node-examples)
- [Multi-Node and Advanced Examples](#advanced-examples) - [Multi-Node and Advanced Examples](#advanced-examples)
...@@ -50,6 +51,31 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) ...@@ -50,6 +51,31 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| **GB200 Support** | ✅ | | | **GB200 Support** | ✅ | |
## Dynamo SGLang Integration
Dynamo SGLang integrates SGLang engines into Dynamo's distributed runtime, enabling advanced features like disaggregated serving, KV-aware routing, and request migration while maintaining full compatibility with SGLang's engine arguments.
### Argument Handling
Dynamo SGLang uses SGLang's native argument parser, so **most SGLang engine arguments work identically**. You can pass any SGLang argument (like `--model-path`, `--tp`, `--trust-remote-code`) directly to `dynamo.sglang`.
#### Dynamo-Specific Arguments
| Argument | Description | Default | SGLang Equivalent |
|----------|-------------|---------|-------------------|
| `--endpoint` | Dynamo endpoint in `dyn://namespace.component.endpoint` format | Auto-generated based on mode | N/A |
| `--migration-limit` | Max times a request can migrate between workers | `0` (disabled) | N/A |
| `--dyn-tool-call-parser` | Tool call parser for structured outputs (takes precedence over `--tool-call-parser`) | `None` | `--tool-call-parser` |
| `--dyn-reasoning-parser` | Reasoning parser for CoT models (takes precedence over `--reasoning-parser`) | `None` | `--reasoning-parser` |
| `--use-sglang-tokenizer` | Use SGLang's tokenizer instead of Dynamo's | `False` | N/A |
#### Tokenizer Behavior
- **Default (`--use-sglang-tokenizer` not set)**: Dynamo handles tokenization and passes `input_ids` to SGLang
- **With `--use-sglang-tokenizer`**: SGLang handles tokenization, Dynamo passes raw prompts
> **Note**: When using `--use-sglang-tokenizer`, only `v1/chat/completions` endpoints are available through Dynamo's frontend.
## SGLang Quick Start ## SGLang Quick Start
Below we provide a guide that lets you run all of our common deployment patterns on a single node. Below we provide a guide that lets you run all of our common deployment patterns on a single node.
......
...@@ -24,5 +24,4 @@ python3 -m dynamo.sglang \ ...@@ -24,5 +24,4 @@ python3 -m dynamo.sglang \
--served-model-name Qwen/Qwen3-0.6B \ --served-model-name Qwen/Qwen3-0.6B \
--page-size 16 \ --page-size 16 \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code
--skip-tokenizer-init
...@@ -25,7 +25,6 @@ python3 -m dynamo.sglang \ ...@@ -25,7 +25,6 @@ python3 -m dynamo.sglang \
--page-size 16 \ --page-size 16 \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' & --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' &
WORKER_PID=$! WORKER_PID=$!
...@@ -35,5 +34,4 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ ...@@ -35,5 +34,4 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--page-size 16 \ --page-size 16 \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}'
...@@ -25,8 +25,9 @@ python3 -m dynamo.sglang \ ...@@ -25,8 +25,9 @@ python3 -m dynamo.sglang \
--page-size 16 \ --page-size 16 \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \ --disaggregation-mode prefill \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl & --disaggregation-transfer-backend nixl &
PREFILL_PID=$! PREFILL_PID=$!
...@@ -37,6 +38,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ ...@@ -37,6 +38,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--page-size 16 \ --page-size 16 \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \ --disaggregation-mode decode \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl --disaggregation-transfer-backend nixl
...@@ -30,7 +30,6 @@ python3 -m dynamo.sglang \ ...@@ -30,7 +30,6 @@ python3 -m dynamo.sglang \
--dp-size 2 \ --dp-size 2 \
--enable-dp-attention \ --enable-dp-attention \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \ --disaggregation-mode prefill \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \ --expert-distribution-recorder-mode stat \
...@@ -45,7 +44,6 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \ ...@@ -45,7 +44,6 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \
--dp-size 2 \ --dp-size 2 \
--enable-dp-attention \ --enable-dp-attention \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \ --disaggregation-mode decode \
--disaggregation-transfer-backend nixl \ --disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \ --expert-distribution-recorder-mode stat \
......
...@@ -49,6 +49,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -49,6 +49,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"choices": get_reasoning_parser_names(), "choices": get_reasoning_parser_names(),
"help": "Reasoning parser name for the model. If not specified, no reasoning parsing is performed.", "help": "Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
}, },
"use-sglang-tokenizer": {
"flags": ["--use-sglang-tokenizer"],
"action": "store_true",
"default": False,
"help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend",
},
} }
...@@ -63,6 +69,9 @@ class DynamoArgs: ...@@ -63,6 +69,9 @@ class DynamoArgs:
tool_call_parser: Optional[str] = None tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None reasoning_parser: Optional[str] = None
# preprocessing options
use_sglang_tokenizer: bool = False
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
AGGREGATED = "agg" AGGREGATED = "agg"
...@@ -127,13 +136,18 @@ def parse_args(args: list[str]) -> Config: ...@@ -127,13 +136,18 @@ def parse_args(args: list[str]) -> Config:
# Dynamo args # Dynamo args
for info in DYNAMO_ARGS.values(): for info in DYNAMO_ARGS.values():
parser.add_argument( kwargs = {
*info["flags"], "default": info["default"] if "default" in info else None,
type=info["type"], "help": info["help"],
default=info["default"] if "default" in info else None, }
help=info["help"], if "type" in info:
choices=info.get("choices", None), kwargs["type"] = info["type"]
) if "choices" in info:
kwargs["choices"] = info["choices"]
if "action" in info:
kwargs["action"] = info["action"]
parser.add_argument(*info["flags"], **kwargs)
# SGLang args # SGLang args
bootstrap_port = _reserve_disaggregation_bootstrap_port() bootstrap_port = _reserve_disaggregation_bootstrap_port()
...@@ -191,15 +205,20 @@ def parse_args(args: list[str]) -> Config: ...@@ -191,15 +205,20 @@ def parse_args(args: list[str]) -> Config:
migration_limit=parsed_args.migration_limit, migration_limit=parsed_args.migration_limit,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
) )
logging.debug(f"Dynamo args: {dynamo_args}") logging.debug(f"Dynamo args: {dynamo_args}")
server_args = ServerArgs.from_cli_args(parsed_args) server_args = ServerArgs.from_cli_args(parsed_args)
if not server_args.skip_tokenizer_init: if parsed_args.use_sglang_tokenizer:
logging.warning( logging.info(
"When using the dynamo frontend (python3 -m dynamo.frontend), we perform tokenization and detokenization " "Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
"in the frontend. Automatically setting --skip-tokenizer-init to True." )
server_args.skip_tokenizer_init = False
else:
logging.info(
"Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
) )
server_args.skip_tokenizer_init = True server_args.skip_tokenizer_init = True
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Optional from typing import List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
TokenIdType = int TokenIdType = int
...@@ -35,7 +36,6 @@ class SamplingOptions(BaseModel): ...@@ -35,7 +36,6 @@ class SamplingOptions(BaseModel):
class PreprocessedRequest(BaseModel): class PreprocessedRequest(BaseModel):
token_ids: List[TokenIdType] token_ids: List[TokenIdType]
batch_token_ids: Optional[List[List[TokenIdType]]] = None
stop_conditions: StopConditions stop_conditions: StopConditions
sampling_options: SamplingOptions sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list) eos_token_ids: List[TokenIdType] = Field(default_factory=list)
...@@ -44,6 +44,6 @@ class PreprocessedRequest(BaseModel): ...@@ -44,6 +44,6 @@ class PreprocessedRequest(BaseModel):
class DisaggPreprocessedRequest(BaseModel): class DisaggPreprocessedRequest(BaseModel):
request: PreprocessedRequest request: Union[PreprocessedRequest, ChatCompletionRequest]
sampling_params: dict sampling_params: dict
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
...@@ -24,10 +24,18 @@ async def register_llm_with_runtime_config( ...@@ -24,10 +24,18 @@ async def register_llm_with_runtime_config(
bool: True if registration succeeded, False if it failed bool: True if registration succeeded, False if it failed
""" """
runtime_config = await _get_runtime_config(engine, dynamo_args) runtime_config = await _get_runtime_config(engine, dynamo_args)
input_type = ModelInput.Tokens
output_type = ModelType.Chat | ModelType.Completions
if not server_args.skip_tokenizer_init:
logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
)
input_type = ModelInput.Text
output_type = ModelType.Chat
try: try:
await register_llm( await register_llm(
ModelInput.Tokens, input_type,
ModelType.Chat | ModelType.Completions, output_type,
endpoint, endpoint,
server_args.model_path, server_args.model_path,
server_args.served_model_name, server_args.served_model_name,
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
import time
import sglang as sgl import sglang as sgl
...@@ -41,20 +42,33 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -41,20 +42,33 @@ class DecodeWorkerHandler(BaseWorkerHandler):
super().cleanup() super().cleanup()
def _build_sampling_params(self, request: dict) -> dict: def _build_sampling_params(self, request: dict) -> dict:
sampling_params = {} """Build sampling params depending on request from frontend"""
if request["sampling_options"]["temperature"]: if self.skip_tokenizer_init:
sampling_params["temperature"] = request["sampling_options"]["temperature"] # Token-based request format
if request["sampling_options"]["top_p"]: sampling_opts = request.get("sampling_options", {})
sampling_params["top_p"] = request["sampling_options"]["top_p"] stop_conditions = request.get("stop_conditions", {})
if request["sampling_options"]["top_k"]:
sampling_params["top_k"] = request["sampling_options"]["top_k"] param_mapping = {
sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"] "temperature": sampling_opts.get("temperature"),
if request["stop_conditions"]["ignore_eos"]: "top_p": sampling_opts.get("top_p"),
sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"] "top_k": sampling_opts.get("top_k"),
return sampling_params "max_new_tokens": stop_conditions.get("max_tokens"),
"ignore_eos": stop_conditions.get("ignore_eos"),
}
else:
# OpenAI request format
param_mapping = {
"temperature": request.get("temperature"),
"top_p": request.get("top_p"),
"top_k": request.get("top_k"),
"max_new_tokens": request.get("max_tokens"),
}
return {k: v for k, v in param_mapping.items() if v is not None}
async def generate(self, request: dict): async def generate(self, request: dict):
sampling_params = self._build_sampling_params(request) sampling_params = self._build_sampling_params(request)
input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker # request the bootstrap info from the target prefill worker
...@@ -74,7 +88,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -74,7 +88,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
raise RuntimeError("No bootstrap info received from prefill worker") raise RuntimeError("No bootstrap info received from prefill worker")
decode = await self.engine.async_generate( decode = await self.engine.async_generate(
input_ids=request["token_ids"], **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_host=bootstrap_info["bootstrap_host"],
...@@ -82,33 +96,69 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -82,33 +96,69 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
) )
async for out in self._process_stream(decode): if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode):
yield out
else:
async for out in self._process_text_stream(decode):
yield out yield out
else: else:
agg = await self.engine.async_generate( agg = await self.engine.async_generate(
input_ids=request["token_ids"], **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
) )
async for out in self._process_stream(agg): if self.skip_tokenizer_init:
async for out in self._process_token_stream(agg):
yield out
else:
async for out in self._process_text_stream(agg):
yield out yield out
async def _process_stream(self, stream_source): async def _process_token_stream(self, stream_source):
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for res in stream_source: async for res in stream_source:
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
else:
try: try:
next_total_toks = len(res["output_ids"]) next_total_toks = len(res["output_ids"])
except KeyError: except KeyError:
raise ValueError( raise ValueError(
f"Missing 'output_ids' in response. This often happens when using skip_tokenizer_init=False. " f"Missing 'output_ids' in response. Response keys: {list(res.keys())}"
f"If you're using ModelType.CHAT or custom model configurations, you may need to modify "
f"the tokenization/detokenization logic in your handler. Response keys: {list(res.keys())}"
) )
out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]} out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
num_output_tokens_so_far = next_total_toks num_output_tokens_so_far = next_total_toks
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
yield out yield out
async def _process_text_stream(self, stream_source):
"""Process stream for text input mode"""
count = 0
async for res in stream_source:
index = res.get("index", 0)
text = res.get("text", "")
finish_reason = res["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
next_count = len(text)
delta = text[count:]
choice_data = {
"index": index,
"delta": {"role": "assistant", "content": delta},
"finish_reason": finish_reason_type,
}
response = {
"id": res["meta_info"]["id"],
"created": int(time.time()),
"choices": [choice_data],
"model": self.config.server_args.served_model_name,
"object": "chat.completion.chunk",
}
yield response
count = next_count
...@@ -27,6 +27,7 @@ class BaseWorkerHandler(ABC): ...@@ -27,6 +27,7 @@ class BaseWorkerHandler(ABC):
self.kv_publisher = kv_publisher self.kv_publisher = kv_publisher
self.prefill_client = prefill_client self.prefill_client = prefill_client
self.serving_mode = config.serving_mode self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
@abstractmethod @abstractmethod
async def generate(self, request: str): async def generate(self, request: str):
...@@ -34,3 +35,15 @@ class BaseWorkerHandler(ABC): ...@@ -34,3 +35,15 @@ class BaseWorkerHandler(ABC):
def cleanup(self): def cleanup(self):
pass pass
def _get_input_param(self, request: dict) -> dict:
"""Get the appropriate input parameter for SGLang"""
if self.skip_tokenizer_init:
return {"input_ids": request["token_ids"]}
else:
# use sglang's chat templating itself but leave tokenization to the
# interal engine's TokenizerManager
prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
request["messages"], tokenize=False, add_generation_prompt=True
)
return {"prompt": prompt}
...@@ -56,8 +56,10 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -56,8 +56,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield bootstrap_info yield bootstrap_info
input_param = self._get_input_param(request["request"])
results = await self.engine.async_generate( results = await self.engine.async_generate(
input_ids=request["request"]["token_ids"], **input_param,
sampling_params=request["sampling_params"], sampling_params=request["sampling_params"],
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=self.bootstrap_host,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment