Unverified Commit f9be2e9e authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: allow framework tokenization/detokenization (#3134)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-44-124.ec2.internal>
parent 6243bcbe
......@@ -23,6 +23,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
## Table of Contents
- [Feature Support Matrix](#feature-support-matrix)
- [Dynamo SGLang Integration](#dynamo-sglang-integration)
- [Quick Start](#quick-start)
- [Single Node Examples](#run-single-node-examples)
- [Multi-Node and Advanced Examples](#advanced-examples)
......@@ -50,6 +51,31 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| **GB200 Support** | ✅ | |
## Dynamo SGLang Integration
Dynamo SGLang integrates SGLang engines into Dynamo's distributed runtime, enabling advanced features like disaggregated serving, KV-aware routing, and request migration while maintaining full compatibility with SGLang's engine arguments.
### Argument Handling
Dynamo SGLang uses SGLang's native argument parser, so **most SGLang engine arguments work identically**. You can pass any SGLang argument (like `--model-path`, `--tp`, `--trust-remote-code`) directly to `dynamo.sglang`.
#### Dynamo-Specific Arguments
| Argument | Description | Default | SGLang Equivalent |
|----------|-------------|---------|-------------------|
| `--endpoint` | Dynamo endpoint in `dyn://namespace.component.endpoint` format | Auto-generated based on mode | N/A |
| `--migration-limit` | Max times a request can migrate between workers | `0` (disabled) | N/A |
| `--dyn-tool-call-parser` | Tool call parser for structured outputs (takes precedence over `--tool-call-parser`) | `None` | `--tool-call-parser` |
| `--dyn-reasoning-parser` | Reasoning parser for CoT models (takes precedence over `--reasoning-parser`) | `None` | `--reasoning-parser` |
| `--use-sglang-tokenizer` | Use SGLang's tokenizer instead of Dynamo's | `False` | N/A |
#### Tokenizer Behavior
- **Default (`--use-sglang-tokenizer` not set)**: Dynamo handles tokenization and passes `input_ids` to SGLang
- **With `--use-sglang-tokenizer`**: SGLang handles tokenization and detokenization; Dynamo passes the raw (untokenized) prompt to SGLang
> **Note**: When using `--use-sglang-tokenizer`, only the `v1/chat/completions` endpoint is available through Dynamo's frontend.
## SGLang Quick Start
Below we provide a guide that lets you run all of our common deployment patterns on a single node.
......
......@@ -24,5 +24,4 @@ python3 -m dynamo.sglang \
--served-model-name Qwen/Qwen3-0.6B \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init
--trust-remote-code
......@@ -25,7 +25,6 @@ python3 -m dynamo.sglang \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' &
WORKER_PID=$!
......@@ -35,5 +34,4 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}'
......@@ -25,8 +25,9 @@ python3 -m dynamo.sglang \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl &
PREFILL_PID=$!
......@@ -37,6 +38,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \
--disaggregation-bootstrap-port 12345 \
--host 0.0.0.0 \
--disaggregation-transfer-backend nixl
......@@ -30,7 +30,6 @@ python3 -m dynamo.sglang \
--dp-size 2 \
--enable-dp-attention \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \
--disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \
......@@ -45,7 +44,6 @@ CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \
--dp-size 2 \
--enable-dp-attention \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \
--disaggregation-transfer-backend nixl \
--expert-distribution-recorder-mode stat \
......
......@@ -49,6 +49,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"choices": get_reasoning_parser_names(),
"help": "Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
},
"use-sglang-tokenizer": {
"flags": ["--use-sglang-tokenizer"],
"action": "store_true",
"default": False,
"help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend",
},
}
......@@ -63,6 +69,9 @@ class DynamoArgs:
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
# preprocessing options
use_sglang_tokenizer: bool = False
class DisaggregationMode(Enum):
AGGREGATED = "agg"
......@@ -127,13 +136,18 @@ def parse_args(args: list[str]) -> Config:
# Dynamo args
for info in DYNAMO_ARGS.values():
parser.add_argument(
*info["flags"],
type=info["type"],
default=info["default"] if "default" in info else None,
help=info["help"],
choices=info.get("choices", None),
)
kwargs = {
"default": info["default"] if "default" in info else None,
"help": info["help"],
}
if "type" in info:
kwargs["type"] = info["type"]
if "choices" in info:
kwargs["choices"] = info["choices"]
if "action" in info:
kwargs["action"] = info["action"]
parser.add_argument(*info["flags"], **kwargs)
# SGLang args
bootstrap_port = _reserve_disaggregation_bootstrap_port()
......@@ -191,15 +205,20 @@ def parse_args(args: list[str]) -> Config:
migration_limit=parsed_args.migration_limit,
tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser,
use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
)
logging.debug(f"Dynamo args: {dynamo_args}")
server_args = ServerArgs.from_cli_args(parsed_args)
if not server_args.skip_tokenizer_init:
logging.warning(
"When using the dynamo frontend (python3 -m dynamo.frontend), we perform tokenization and detokenization "
"in the frontend. Automatically setting --skip-tokenizer-init to True."
if parsed_args.use_sglang_tokenizer:
logging.info(
"Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
)
server_args.skip_tokenizer_init = False
else:
logging.info(
"Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
)
server_args.skip_tokenizer_init = True
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
TokenIdType = int
......@@ -35,7 +36,6 @@ class SamplingOptions(BaseModel):
class PreprocessedRequest(BaseModel):
token_ids: List[TokenIdType]
batch_token_ids: Optional[List[List[TokenIdType]]] = None
stop_conditions: StopConditions
sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
......@@ -44,6 +44,6 @@ class PreprocessedRequest(BaseModel):
class DisaggPreprocessedRequest(BaseModel):
    """Request envelope used on the disaggregated (prefill) path.

    The prefill handler reads ``request`` to choose the engine input and
    forwards ``sampling_params`` to the engine unchanged.
    """

    # Either Dynamo's pre-tokenized request or an SGLang/OpenAI chat request
    # (presumably the latter when SGLang's own tokenizer is in use — confirm
    # against the frontend).
    request: Union[PreprocessedRequest, ChatCompletionRequest]
    # Sampling parameters handed to the engine as-is.
    sampling_params: dict
    # NOTE(review): looks like the target data-parallel rank; None when unset
    # — confirm against the caller.
    data_parallel_rank: Optional[int] = None
......@@ -24,10 +24,18 @@ async def register_llm_with_runtime_config(
bool: True if registration succeeded, False if it failed
"""
runtime_config = await _get_runtime_config(engine, dynamo_args)
input_type = ModelInput.Tokens
output_type = ModelType.Chat | ModelType.Completions
if not server_args.skip_tokenizer_init:
logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
)
input_type = ModelInput.Text
output_type = ModelType.Chat
try:
await register_llm(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
input_type,
output_type,
endpoint,
server_args.model_path,
server_args.served_model_name,
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import time
import sglang as sgl
......@@ -41,20 +42,33 @@ class DecodeWorkerHandler(BaseWorkerHandler):
super().cleanup()
def _build_sampling_params(self, request: dict) -> dict:
sampling_params = {}
if request["sampling_options"]["temperature"]:
sampling_params["temperature"] = request["sampling_options"]["temperature"]
if request["sampling_options"]["top_p"]:
sampling_params["top_p"] = request["sampling_options"]["top_p"]
if request["sampling_options"]["top_k"]:
sampling_params["top_k"] = request["sampling_options"]["top_k"]
sampling_params["max_new_tokens"] = request["stop_conditions"]["max_tokens"]
if request["stop_conditions"]["ignore_eos"]:
sampling_params["ignore_eos"] = request["stop_conditions"]["ignore_eos"]
return sampling_params
"""Build sampling params depending on request from frontend"""
if self.skip_tokenizer_init:
# Token-based request format
sampling_opts = request.get("sampling_options", {})
stop_conditions = request.get("stop_conditions", {})
param_mapping = {
"temperature": sampling_opts.get("temperature"),
"top_p": sampling_opts.get("top_p"),
"top_k": sampling_opts.get("top_k"),
"max_new_tokens": stop_conditions.get("max_tokens"),
"ignore_eos": stop_conditions.get("ignore_eos"),
}
else:
# OpenAI request format
param_mapping = {
"temperature": request.get("temperature"),
"top_p": request.get("top_p"),
"top_k": request.get("top_k"),
"max_new_tokens": request.get("max_tokens"),
}
return {k: v for k, v in param_mapping.items() if v is not None}
async def generate(self, request: dict):
sampling_params = self._build_sampling_params(request)
input_param = self._get_input_param(request)
if self.serving_mode == DisaggregationMode.DECODE:
# request the bootstrap info from the target prefill worker
......@@ -74,7 +88,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
raise RuntimeError("No bootstrap info received from prefill worker")
decode = await self.engine.async_generate(
input_ids=request["token_ids"],
**input_param,
sampling_params=sampling_params,
stream=True,
bootstrap_host=bootstrap_info["bootstrap_host"],
......@@ -82,33 +96,69 @@ class DecodeWorkerHandler(BaseWorkerHandler):
bootstrap_room=bootstrap_info["bootstrap_room"],
)
async for out in self._process_stream(decode):
if self.skip_tokenizer_init:
async for out in self._process_token_stream(decode):
yield out
else:
async for out in self._process_text_stream(decode):
yield out
else:
agg = await self.engine.async_generate(
input_ids=request["token_ids"],
**input_param,
sampling_params=sampling_params,
stream=True,
)
async for out in self._process_stream(agg):
if self.skip_tokenizer_init:
async for out in self._process_token_stream(agg):
yield out
else:
async for out in self._process_text_stream(agg):
yield out
async def _process_stream(self, stream_source):
async def _process_token_stream(self, stream_source):
num_output_tokens_so_far = 0
async for res in stream_source:
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
else:
try:
next_total_toks = len(res["output_ids"])
except KeyError:
raise ValueError(
f"Missing 'output_ids' in response. This often happens when using skip_tokenizer_init=False. "
f"If you're using ModelType.CHAT or custom model configurations, you may need to modify "
f"the tokenization/detokenization logic in your handler. Response keys: {list(res.keys())}"
f"Missing 'output_ids' in response. Response keys: {list(res.keys())}"
)
out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
num_output_tokens_so_far = next_total_toks
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
yield out
async def _process_text_stream(self, stream_source):
"""Process stream for text input mode"""
count = 0
async for res in stream_source:
index = res.get("index", 0)
text = res.get("text", "")
finish_reason = res["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
next_count = len(text)
delta = text[count:]
choice_data = {
"index": index,
"delta": {"role": "assistant", "content": delta},
"finish_reason": finish_reason_type,
}
response = {
"id": res["meta_info"]["id"],
"created": int(time.time()),
"choices": [choice_data],
"model": self.config.server_args.served_model_name,
"object": "chat.completion.chunk",
}
yield response
count = next_count
......@@ -27,6 +27,7 @@ class BaseWorkerHandler(ABC):
self.kv_publisher = kv_publisher
self.prefill_client = prefill_client
self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
@abstractmethod
async def generate(self, request: str):
......@@ -34,3 +35,15 @@ class BaseWorkerHandler(ABC):
def cleanup(self):
pass
def _get_input_param(self, request: dict) -> dict:
"""Get the appropriate input parameter for SGLang"""
if self.skip_tokenizer_init:
return {"input_ids": request["token_ids"]}
else:
# use sglang's chat templating itself but leave tokenization to the
# interal engine's TokenizerManager
prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template(
request["messages"], tokenize=False, add_generation_prompt=True
)
return {"prompt": prompt}
......@@ -56,8 +56,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
yield bootstrap_info
input_param = self._get_input_param(request["request"])
results = await self.engine.async_generate(
input_ids=request["request"]["token_ids"],
**input_param,
sampling_params=request["sampling_params"],
stream=True,
bootstrap_host=self.bootstrap_host,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment