Unverified Commit d809906e authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Add embedding support to sgl backend (#3427)


Signed-off-by: default avatarkrishung5 <krish@nvidia.com>
parent 4c888bf4
......@@ -175,6 +175,26 @@ cd $DYNAMO_HOME/components/backends/sglang
./launch/agg_router.sh
```
### Aggregated Serving with Embeddings
Here's an example that uses the [Qwen/Qwen3-Embedding-4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B) model.
```bash
cd $DYNAMO_HOME/components/backends/sglang
./launch/agg_embed.sh
```
Send the following request to verify your deployment:
```bash
curl localhost:8000/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-Embedding-4B",
"input": "Hello, world!"
}'
```
### Disaggregated serving
<details>
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Setup cleanup trap
cleanup() {
echo "Cleaning up background processes..."
kill $DYNAMO_PID 2>/dev/null || true
wait $DYNAMO_PID 2>/dev/null || true
echo "Cleanup complete."
}
trap cleanup EXIT INT TERM
# run clear_namespace
python3 -m dynamo.sglang.utils.clear_namespace --namespace dynamo
# run ingress
python3 -m dynamo.frontend --http-port=8000 &
DYNAMO_PID=$!
# run worker
python3 -m dynamo.sglang \
--embedding-worker \
--model-path Qwen/Qwen3-Embedding-4B \
--served-model-name Qwen/Qwen3-Embedding-4B \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--use-sglang-tokenizer
......@@ -79,6 +79,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": False,
"help": "Run as multimodal worker component for LLM inference with multimodal data",
},
"embedding-worker": {
"flags": ["--embedding-worker"],
"action": "store_true",
"default": False,
"help": "Run as embedding worker component (Dynamo flag, also sets SGLang's --is-embedding)",
},
}
......@@ -102,6 +108,9 @@ class DynamoArgs:
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
# embedding options
embedding_worker: bool = False
class DisaggregationMode(Enum):
AGGREGATED = "agg"
......@@ -221,9 +230,15 @@ def parse_args(args: list[str]) -> Config:
# otherwise fall back to default endpoints
namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
# If --embedding-worker is set, also set SGLang's --is-embedding flag
if parsed_args.embedding_worker:
parsed_args.is_embedding = True
endpoint = parsed_args.endpoint
if endpoint is None:
if (
if parsed_args.embedding_worker:
endpoint = f"dyn://{namespace}.backend.generate"
elif (
hasattr(parsed_args, "disaggregation_mode")
and parsed_args.disaggregation_mode == "prefill"
):
......@@ -291,6 +306,7 @@ def parse_args(args: list[str]) -> Config:
multimodal_processor=parsed_args.multimodal_processor,
multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
)
logging.debug(f"Dynamo args: {dynamo_args}")
......
......@@ -9,7 +9,7 @@ import sys
import sglang as sgl
import uvloop
from dynamo.llm import ModelInput
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
......@@ -21,6 +21,7 @@ from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import (
DecodeWorkerHandler,
EmbeddingWorkerHandler,
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
......@@ -44,7 +45,9 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = parse_args(sys.argv[1:])
if config.dynamo_args.multimodal_processor:
if config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config)
elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
......@@ -158,6 +161,63 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init_embedding(runtime: DistributedRuntime, config: Config):
"""Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
engine = sgl.Engine(server_args=server_args)
component = runtime.namespace(dynamo_args.namespace).component(
dynamo_args.component
)
await component.create_service()
generate_endpoint = component.endpoint(dynamo_args.endpoint)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
)
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler(component, engine, config, publisher)
health_check_payload = SglangHealthCheckPayload(engine).to_dict()
try:
# Start endpoint immediately and register model concurrently
# Requests queue until ready_event is set
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=metrics_labels,
health_check_payload=health_check_payload,
),
register_llm_with_readiness_gate(
engine,
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
output_type=ModelType.Embedding,
readiness_gate=ready_event,
),
)
except Exception as e:
logging.error(f"Failed to serve embedding endpoints: {e}")
raise
finally:
metrics_task.cancel()
try:
await metrics_task
except asyncio.CancelledError:
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -43,6 +43,18 @@ class PreprocessedRequest(BaseModel):
annotations: List[str] = Field(default_factory=list)
EmbeddingInput = Union[str, List[str], List[int], List[List[int]]]
class EmbeddingRequest(BaseModel):
model: str
input: EmbeddingInput
user: Optional[str] = None
dimensions: Optional[
int
] = None # only supported in text-embedding-3 and later models from OpenAI
class DisaggPreprocessedRequest(BaseModel):
request: Union[PreprocessedRequest, ChatCompletionRequest]
sampling_params: dict
......
......@@ -19,6 +19,7 @@ async def _register_llm_with_runtime_config(
server_args: ServerArgs,
dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
) -> bool:
"""Register LLM with the Dynamo runtime.
......@@ -28,19 +29,23 @@ async def _register_llm_with_runtime_config(
server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration.
input_type: Expected model input type. Defaults to ModelInput.Tokens.
output_type: Expected model output type. Defaults to ModelType.Chat | ModelType.Completions.
Returns:
True if registration succeeded, False otherwise.
"""
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = input_type
output_type = ModelType.Chat | ModelType.Completions
if not server_args.skip_tokenizer_init:
logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
)
input_type = ModelInput.Text
# Only override output_type for chat models, not for embeddings
if output_type != ModelType.Embedding:
output_type = ModelType.Chat
try:
await register_llm(
input_type,
......@@ -134,6 +139,7 @@ async def register_llm_with_readiness_gate(
server_args: ServerArgs,
dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
readiness_gate: Optional[asyncio.Event] = None,
) -> None:
"""Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success.
......@@ -144,6 +150,7 @@ async def register_llm_with_readiness_gate(
server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration.
input_type: Expected model input type. Defaults to ModelInput.Tokens.
output_type: Expected model output type. Defaults to ModelType.Chat | ModelType.Completions.
readiness_gate: Optional event to signal when registration completes.
Raises:
......@@ -155,6 +162,7 @@ async def register_llm_with_readiness_gate(
server_args,
dynamo_args,
input_type,
output_type,
)
if not registration_success:
logging.error("Model registration failed; shutting down")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler
# Embedding handlers
from .embedding import EmbeddingWorkerHandler
# Base handlers
from .handler_base import BaseWorkerHandler
# LLM handlers
from .llm import DecodeWorkerHandler, PrefillWorkerHandler
# Multimodal handlers
from .multimodal_encode_worker_handler import MultimodalEncodeWorkerHandler
from .multimodal_processor_handler import MultimodalProcessorHandler
from .multimodal_worker_handler import (
from .multimodal import (
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
MultimodalWorkerHandler,
)
from .prefill_handler import PrefillWorkerHandler
__all__ = [
"BaseWorkerHandler",
# LLM handlers
"DecodeWorkerHandler",
"PrefillWorkerHandler",
"MultimodalProcessorHandler",
# Embedding handlers
"EmbeddingWorkerHandler",
# Multimodal handlers
"MultimodalEncodeWorkerHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .embedding_handler import EmbeddingWorkerHandler
__all__ = [
"EmbeddingWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import sglang as sgl
from dynamo._core import Component
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class EmbeddingWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
):
super().__init__(component, engine, config, publisher)
logging.info("Embedding worker handler initialized")
def cleanup(self):
self.engine.shutdown()
logging.info("Engine shutdown")
super().cleanup()
async def generate(self, request: dict):
logging.debug(f"Embedding request: {request}")
# Parse the embedding request - should only receive EmbeddingRequest format
embedding_request = EmbeddingRequest(**request)
# Handle different input types
if isinstance(embedding_request.input, str):
prompt = embedding_request.input
elif isinstance(embedding_request.input, list):
prompt = embedding_request.input
else:
raise TypeError(f"Invalid input type: {type(embedding_request.input)}")
result = await self.engine.async_encode(prompt=prompt)
# Transform the response to OpenAI format
response = self._transform_response(result, embedding_request.model)
yield response
def _transform_response(self, ret, model_name):
"""Transform SGLang response to OpenAI embedding format"""
if not isinstance(ret, list):
ret = [ret]
embedding_objects = []
prompt_tokens = 0
for idx, ret_item in enumerate(ret):
embedding_objects.append(
{
"object": "embedding",
"embedding": ret_item["embedding"],
"index": idx,
}
)
prompt_tokens += ret_item.get("meta_info", {}).get("prompt_tokens", 0)
return {
"object": "list",
"data": embedding_objects,
"model": model_name,
"usage": {
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens,
},
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler
from .prefill_handler import PrefillWorkerHandler
__all__ = [
"DecodeWorkerHandler",
"PrefillWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .encode_worker_handler import MultimodalEncodeWorkerHandler
from .processor_handler import MultimodalProcessorHandler
from .worker_handler import MultimodalPrefillWorkerHandler, MultimodalWorkerHandler
__all__ = [
"MultimodalEncodeWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler",
]
......@@ -388,6 +388,18 @@ impl ModelWatcher {
.context("add_completions_model")?;
tracing::info!("Completions is ready");
}
} else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
// Case: Text + Embeddings
let push_router = PushRouter::<
NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_embeddings_model(&model_entry.name, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat
let push_router = PushRouter::<
......
......@@ -17,6 +17,8 @@ from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
completion_payload_default,
embedding_payload,
embedding_payload_default,
)
logger = logging.getLogger(__name__)
......@@ -116,6 +118,39 @@ sglang_configs = {
)
],
),
"embedding_agg": SGLangConfig(
name="embedding_agg",
directory=sglang_dir,
script_name="agg_embed.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-Embedding-4B",
delayed_start=0,
timeout=180,
models_port=8000,
request_payloads=[
# Test default payload with multiple inputs
embedding_payload_default(
repeat_count=2,
expected_response=["Generated 2 embeddings with dimension"],
),
# Test single string input
embedding_payload(
input_text="Hello, world!",
repeat_count=1,
expected_response=["Generated 1 embeddings with dimension"],
),
# Test multiple string inputs
embedding_payload(
input_text=[
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming technology.",
"Natural language processing enables computers to understand text.",
],
repeat_count=1,
expected_response=["Generated 3 embeddings with dimension"],
),
],
),
}
......
......@@ -11,10 +11,12 @@ import os
QWEN = "Qwen/Qwen3-0.6B"
LLAMA = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # on an l4 gpu, must limit --max-seq-len, otherwise it will not fit
QWEN_EMBEDDING = "Qwen/Qwen3-Embedding-4B"
TEST_MODELS = [
QWEN,
LLAMA,
QWEN_EMBEDDING,
]
# Env-driven defaults for specific test groups
......
......@@ -4,7 +4,12 @@
from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request
from tests.utils.payloads import ChatPayload, CompletionPayload, MetricsPayload
from tests.utils.payloads import (
ChatPayload,
CompletionPayload,
EmbeddingPayload,
MetricsPayload,
)
# Common default text prompt used across tests
TEXT_PROMPT = "Tell me a short joke about AI."
......@@ -123,6 +128,47 @@ def completion_payload(
)
def embedding_payload_default(
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
) -> EmbeddingPayload:
return EmbeddingPayload(
body={
"input": ["The sky is blue.", "Machine learning is fascinating."],
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response
or ["Generated 2 embeddings with dimension"],
)
def embedding_payload(
input_text: Union[str, List[str]],
repeat_count: int = 3,
expected_response: Optional[List[str]] = None,
expected_log: Optional[List[str]] = None,
) -> EmbeddingPayload:
# Normalize input to list for consistent processing
if isinstance(input_text, str):
input_list = [input_text]
expected_count = 1
else:
input_list = input_text
expected_count = len(input_text)
return EmbeddingPayload(
body={
"input": input_list,
},
repeat_count=repeat_count,
expected_log=expected_log or [],
expected_response=expected_response
or [f"Generated {expected_count} embeddings with dimension"],
)
# Build small request-based health checks for chat and completions
# these should only be used as a last resort. Generally want to use an actual health check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment