Unverified commit d809906e, authored by Kris Hung and committed by GitHub
Browse files

feat: Add embedding support to sgl backend (#3427)


Signed-off-by: krishung5 <krish@nvidia.com>
parent 4c888bf4
...@@ -175,6 +175,26 @@ cd $DYNAMO_HOME/components/backends/sglang ...@@ -175,6 +175,26 @@ cd $DYNAMO_HOME/components/backends/sglang
./launch/agg_router.sh ./launch/agg_router.sh
``` ```
### Aggregated Serving with Embeddings
Here's an example that uses the [Qwen/Qwen3-Embedding-4B](https://huggingface.co/Qwen/Qwen3-Embedding-4B) model.
```bash
cd $DYNAMO_HOME/components/backends/sglang
./launch/agg_embed.sh
```
Send the following request to verify your deployment:
```bash
curl localhost:8000/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-Embedding-4B",
"input": "Hello, world!"
}'
```
### Disaggregated serving ### Disaggregated serving
<details> <details>
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Launches an aggregated embedding deployment: clears the "dynamo" namespace,
# starts the frontend on HTTP port 8000 in the background, then runs the
# SGLang embedding worker in the foreground.

# PID of the background frontend. Initialised to empty so that a DYNAMO_PID
# inherited from the caller's environment can never be killed by cleanup().
DYNAMO_PID=""

# Setup cleanup trap
cleanup() {
    echo "Cleaning up background processes..."
    # Only signal/wait when the frontend was actually started; a bare `wait`
    # with no argument would otherwise wait on every child of this shell.
    if [ -n "$DYNAMO_PID" ]; then
        kill "$DYNAMO_PID" 2>/dev/null || true
        wait "$DYNAMO_PID" 2>/dev/null || true
    fi
    echo "Cleanup complete."
}
trap cleanup EXIT INT TERM

# run clear_namespace
python3 -m dynamo.sglang.utils.clear_namespace --namespace dynamo

# run ingress
python3 -m dynamo.frontend --http-port=8000 &
DYNAMO_PID=$!

# run worker
python3 -m dynamo.sglang \
  --embedding-worker \
  --model-path Qwen/Qwen3-Embedding-4B \
  --served-model-name Qwen/Qwen3-Embedding-4B \
  --page-size 16 \
  --tp 1 \
  --trust-remote-code \
  --use-sglang-tokenizer
...@@ -79,6 +79,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = { ...@@ -79,6 +79,12 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": False, "default": False,
"help": "Run as multimodal worker component for LLM inference with multimodal data", "help": "Run as multimodal worker component for LLM inference with multimodal data",
}, },
"embedding-worker": {
"flags": ["--embedding-worker"],
"action": "store_true",
"default": False,
"help": "Run as embedding worker component (Dynamo flag, also sets SGLang's --is-embedding)",
},
} }
...@@ -102,6 +108,9 @@ class DynamoArgs: ...@@ -102,6 +108,9 @@ class DynamoArgs:
multimodal_encode_worker: bool = False multimodal_encode_worker: bool = False
multimodal_worker: bool = False multimodal_worker: bool = False
# embedding options
embedding_worker: bool = False
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
AGGREGATED = "agg" AGGREGATED = "agg"
...@@ -221,9 +230,15 @@ def parse_args(args: list[str]) -> Config: ...@@ -221,9 +230,15 @@ def parse_args(args: list[str]) -> Config:
# otherwise fall back to default endpoints # otherwise fall back to default endpoints
namespace = os.environ.get("DYN_NAMESPACE", "dynamo") namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
# If --embedding-worker is set, also set SGLang's --is-embedding flag
if parsed_args.embedding_worker:
parsed_args.is_embedding = True
endpoint = parsed_args.endpoint endpoint = parsed_args.endpoint
if endpoint is None: if endpoint is None:
if ( if parsed_args.embedding_worker:
endpoint = f"dyn://{namespace}.backend.generate"
elif (
hasattr(parsed_args, "disaggregation_mode") hasattr(parsed_args, "disaggregation_mode")
and parsed_args.disaggregation_mode == "prefill" and parsed_args.disaggregation_mode == "prefill"
): ):
...@@ -291,6 +306,7 @@ def parse_args(args: list[str]) -> Config: ...@@ -291,6 +306,7 @@ def parse_args(args: list[str]) -> Config:
multimodal_processor=parsed_args.multimodal_processor, multimodal_processor=parsed_args.multimodal_processor,
multimodal_encode_worker=parsed_args.multimodal_encode_worker, multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker, multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker,
) )
logging.debug(f"Dynamo args: {dynamo_args}") logging.debug(f"Dynamo args: {dynamo_args}")
......
...@@ -9,7 +9,7 @@ import sys ...@@ -9,7 +9,7 @@ import sys
import sglang as sgl import sglang as sgl
import uvloop import uvloop
from dynamo.llm import ModelInput from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.args import Config, DisaggregationMode, parse_args
...@@ -21,6 +21,7 @@ from dynamo.sglang.publisher import setup_sgl_metrics ...@@ -21,6 +21,7 @@ from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_readiness_gate from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
EmbeddingWorkerHandler,
MultimodalEncodeWorkerHandler, MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler, MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler, MultimodalProcessorHandler,
...@@ -44,7 +45,9 @@ async def worker(runtime: DistributedRuntime): ...@@ -44,7 +45,9 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime") logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = parse_args(sys.argv[1:]) config = parse_args(sys.argv[1:])
if config.dynamo_args.multimodal_processor: if config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config) await init_multimodal_processor(runtime, config)
elif config.dynamo_args.multimodal_encode_worker: elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config) await init_multimodal_encode_worker(runtime, config)
...@@ -158,6 +161,63 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -158,6 +161,63 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init_embedding(runtime: DistributedRuntime, config: Config):
    """Initialize embedding worker component.

    Creates an SGLang engine from ``config.server_args``, creates the
    component service and its endpoint, then concurrently serves
    ``EmbeddingWorkerHandler.generate`` on that endpoint and registers the
    model with ``ModelInput.Text`` input and ``ModelType.Embedding`` output.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component.
        config: Parsed configuration holding ``server_args`` and
            ``dynamo_args``.

    Raises:
        Exception: Any error raised while serving the endpoint or registering
            the model is logged and re-raised.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()

    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # publisher instantiates the metrics and kv event publishers
    publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
        engine, config, component, generate_endpoint
    )

    # Readiness gate: requests wait until model is registered
    ready_event = asyncio.Event()

    handler = EmbeddingWorkerHandler(component, engine, config, publisher)
    health_check_payload = SglangHealthCheckPayload(engine).to_dict()

    try:
        # Start endpoint immediately and register model concurrently
        # Requests queue until ready_event is set
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            register_llm_with_readiness_gate(
                engine,
                generate_endpoint,
                server_args,
                dynamo_args,
                input_type=ModelInput.Text,
                output_type=ModelType.Embedding,
                readiness_gate=ready_event,
            ),
        )
    except Exception as e:
        logging.error(f"Failed to serve embedding endpoints: {e}")
        raise
    finally:
        try:
            metrics_task.cancel()
            try:
                await metrics_task
            except asyncio.CancelledError:
                logging.info("Metrics task successfully cancelled")
        finally:
            # Run handler cleanup (which shuts the engine down) even when the
            # metrics task finished with an error other than cancellation;
            # that error still propagates after cleanup.
            handler.cleanup()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config): async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -43,6 +43,18 @@ class PreprocessedRequest(BaseModel): ...@@ -43,6 +43,18 @@ class PreprocessedRequest(BaseModel):
annotations: List[str] = Field(default_factory=list) annotations: List[str] = Field(default_factory=list)
# Accepted forms of the ``input`` field: one string, a list of strings,
# a list of ints, or a list of int lists.
EmbeddingInput = Union[str, List[str], List[int], List[List[int]]]


class EmbeddingRequest(BaseModel):
    # Request body parsed by the embedding worker handler.
    model: str
    input: EmbeddingInput
    user: Optional[str] = None
    # only supported in text-embedding-3 and later models from OpenAI
    dimensions: Optional[int] = None
class DisaggPreprocessedRequest(BaseModel): class DisaggPreprocessedRequest(BaseModel):
request: Union[PreprocessedRequest, ChatCompletionRequest] request: Union[PreprocessedRequest, ChatCompletionRequest]
sampling_params: dict sampling_params: dict
......
...@@ -19,6 +19,7 @@ async def _register_llm_with_runtime_config( ...@@ -19,6 +19,7 @@ async def _register_llm_with_runtime_config(
server_args: ServerArgs, server_args: ServerArgs,
dynamo_args: DynamoArgs, dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens, input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
) -> bool: ) -> bool:
"""Register LLM with the Dynamo runtime. """Register LLM with the Dynamo runtime.
...@@ -28,19 +29,23 @@ async def _register_llm_with_runtime_config( ...@@ -28,19 +29,23 @@ async def _register_llm_with_runtime_config(
server_args: SGLang server configuration. server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration. dynamo_args: Dynamo-specific configuration.
input_type: Expected model input type. Defaults to ModelInput.Tokens. input_type: Expected model input type. Defaults to ModelInput.Tokens.
output_type: Expected model output type. Defaults to ModelType.Chat | ModelType.Completions.
Returns: Returns:
True if registration succeeded, False otherwise. True if registration succeeded, False otherwise.
""" """
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args) runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = input_type input_type = input_type
output_type = ModelType.Chat | ModelType.Completions
if not server_args.skip_tokenizer_init: if not server_args.skip_tokenizer_init:
logging.warning( logging.warning(
"The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available" "The skip-tokenizer-init flag was not set. Using the sglang tokenizer/detokenizer instead. The dynamo tokenizer/detokenizer will not be used and only v1/chat/completions will be available"
) )
input_type = ModelInput.Text input_type = ModelInput.Text
output_type = ModelType.Chat # Only override output_type for chat models, not for embeddings
if output_type != ModelType.Embedding:
output_type = ModelType.Chat
try: try:
await register_llm( await register_llm(
input_type, input_type,
...@@ -134,6 +139,7 @@ async def register_llm_with_readiness_gate( ...@@ -134,6 +139,7 @@ async def register_llm_with_readiness_gate(
server_args: ServerArgs, server_args: ServerArgs,
dynamo_args: DynamoArgs, dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens, input_type: Optional[ModelInput] = ModelInput.Tokens,
output_type: Optional[ModelType] = ModelType.Chat | ModelType.Completions,
readiness_gate: Optional[asyncio.Event] = None, readiness_gate: Optional[asyncio.Event] = None,
) -> None: ) -> None:
"""Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success. """Wrapper function to register LLM with the Dynamo runtime and use optional readiness gate to signal success.
...@@ -144,6 +150,7 @@ async def register_llm_with_readiness_gate( ...@@ -144,6 +150,7 @@ async def register_llm_with_readiness_gate(
server_args: SGLang server configuration. server_args: SGLang server configuration.
dynamo_args: Dynamo-specific configuration. dynamo_args: Dynamo-specific configuration.
input_type: Expected model input type. Defaults to ModelInput.Tokens. input_type: Expected model input type. Defaults to ModelInput.Tokens.
output_type: Expected model output type. Defaults to ModelType.Chat | ModelType.Completions.
readiness_gate: Optional event to signal when registration completes. readiness_gate: Optional event to signal when registration completes.
Raises: Raises:
...@@ -155,6 +162,7 @@ async def register_llm_with_readiness_gate( ...@@ -155,6 +162,7 @@ async def register_llm_with_readiness_gate(
server_args, server_args,
dynamo_args, dynamo_args,
input_type, input_type,
output_type,
) )
if not registration_success: if not registration_success:
logging.error("Model registration failed; shutting down") logging.error("Model registration failed; shutting down")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler # Embedding handlers
from .embedding import EmbeddingWorkerHandler
# Base handlers # Base handlers
from .handler_base import BaseWorkerHandler from .handler_base import BaseWorkerHandler
# LLM handlers
from .llm import DecodeWorkerHandler, PrefillWorkerHandler
# Multimodal handlers # Multimodal handlers
from .multimodal_encode_worker_handler import MultimodalEncodeWorkerHandler from .multimodal import (
from .multimodal_processor_handler import MultimodalProcessorHandler MultimodalEncodeWorkerHandler,
from .multimodal_worker_handler import (
MultimodalPrefillWorkerHandler, MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
MultimodalWorkerHandler, MultimodalWorkerHandler,
) )
from .prefill_handler import PrefillWorkerHandler
__all__ = [ __all__ = [
"BaseWorkerHandler", "BaseWorkerHandler",
# LLM handlers
"DecodeWorkerHandler", "DecodeWorkerHandler",
"PrefillWorkerHandler", "PrefillWorkerHandler",
"MultimodalProcessorHandler", # Embedding handlers
"EmbeddingWorkerHandler",
# Multimodal handlers
"MultimodalEncodeWorkerHandler", "MultimodalEncodeWorkerHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler", "MultimodalPrefillWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .embedding_handler import EmbeddingWorkerHandler
__all__ = [
"EmbeddingWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import sglang as sgl
from dynamo._core import Component
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class EmbeddingWorkerHandler(BaseWorkerHandler):
    """Request handler that runs embedding requests through an SGLang engine."""

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        publisher: Optional[DynamoSglangPublisher] = None,
    ):
        """Forward all arguments to ``BaseWorkerHandler`` and log start-up."""
        super().__init__(component, engine, config, publisher)
        logging.info("Embedding worker handler initialized")

    def cleanup(self):
        """Shut the engine down, then run the base-class cleanup."""
        self.engine.shutdown()
        logging.info("Engine shutdown")
        super().cleanup()

    async def generate(self, request: dict):
        """Encode the request's input and yield one OpenAI-style response.

        Args:
            request: Mapping of ``EmbeddingRequest`` fields.

        Yields:
            The dict produced by ``_transform_response``.

        Raises:
            TypeError: If the parsed input is neither a ``str`` nor a ``list``.
        """
        logging.debug(f"Embedding request: {request}")

        # Parse the embedding request - should only receive EmbeddingRequest format
        parsed = EmbeddingRequest(**request)

        # Strings and lists are handed to the engine unchanged; anything else
        # is rejected.
        payload = parsed.input
        if not isinstance(payload, (str, list)):
            raise TypeError(f"Invalid input type: {type(payload)}")

        encoded = await self.engine.async_encode(prompt=payload)

        # Transform the response to OpenAI format
        yield self._transform_response(encoded, parsed.model)

    def _transform_response(self, ret, model_name):
        """Transform SGLang response to OpenAI embedding format.

        Args:
            ret: A single result mapping or a list of them; each must have an
                ``"embedding"`` key and may carry
                ``meta_info["prompt_tokens"]``.
            model_name: Value reported in the response's ``"model"`` field.

        Returns:
            A dict with ``"object"``, ``"data"``, ``"model"`` and ``"usage"``.
        """
        items = ret if isinstance(ret, list) else [ret]

        data = []
        token_total = 0
        for position, item in enumerate(items):
            entry = {
                "object": "embedding",
                "embedding": item["embedding"],
                "index": position,
            }
            data.append(entry)
            # Items without token metadata count as zero tokens.
            token_total += item.get("meta_info", {}).get("prompt_tokens", 0)

        usage = {
            "prompt_tokens": token_total,
            "total_tokens": token_total,
        }
        return {
            "object": "list",
            "data": data,
            "model": model_name,
            "usage": usage,
        }
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler
from .prefill_handler import PrefillWorkerHandler
__all__ = [
"DecodeWorkerHandler",
"PrefillWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .encode_worker_handler import MultimodalEncodeWorkerHandler
from .processor_handler import MultimodalProcessorHandler
from .worker_handler import MultimodalPrefillWorkerHandler, MultimodalWorkerHandler
__all__ = [
"MultimodalEncodeWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler",
]
...@@ -388,6 +388,18 @@ impl ModelWatcher { ...@@ -388,6 +388,18 @@ impl ModelWatcher {
.context("add_completions_model")?; .context("add_completions_model")?;
tracing::info!("Completions is ready"); tracing::info!("Completions is ready");
} }
} else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
// Case: Text + Embeddings
let push_router = PushRouter::<
NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>,
>::from_client_with_threshold(
client, self.router_mode, self.busy_threshold
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_embeddings_model(&model_entry.name, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() { } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat // Case 3: Text + Chat
let push_router = PushRouter::< let push_router = PushRouter::<
......
...@@ -17,6 +17,8 @@ from tests.utils.payload_builder import ( ...@@ -17,6 +17,8 @@ from tests.utils.payload_builder import (
chat_payload, chat_payload,
chat_payload_default, chat_payload_default,
completion_payload_default, completion_payload_default,
embedding_payload,
embedding_payload_default,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -116,6 +118,39 @@ sglang_configs = { ...@@ -116,6 +118,39 @@ sglang_configs = {
) )
], ],
), ),
"embedding_agg": SGLangConfig(
name="embedding_agg",
directory=sglang_dir,
script_name="agg_embed.sh",
marks=[pytest.mark.gpu_1],
model="Qwen/Qwen3-Embedding-4B",
delayed_start=0,
timeout=180,
models_port=8000,
request_payloads=[
# Test default payload with multiple inputs
embedding_payload_default(
repeat_count=2,
expected_response=["Generated 2 embeddings with dimension"],
),
# Test single string input
embedding_payload(
input_text="Hello, world!",
repeat_count=1,
expected_response=["Generated 1 embeddings with dimension"],
),
# Test multiple string inputs
embedding_payload(
input_text=[
"The quick brown fox jumps over the lazy dog.",
"Machine learning is transforming technology.",
"Natural language processing enables computers to understand text.",
],
repeat_count=1,
expected_response=["Generated 3 embeddings with dimension"],
),
],
),
} }
......
...@@ -11,10 +11,12 @@ import os ...@@ -11,10 +11,12 @@ import os
QWEN = "Qwen/Qwen3-0.6B" QWEN = "Qwen/Qwen3-0.6B"
LLAMA = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # on an l4 gpu, must limit --max-seq-len, otherwise it will not fit LLAMA = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # on an l4 gpu, must limit --max-seq-len, otherwise it will not fit
QWEN_EMBEDDING = "Qwen/Qwen3-Embedding-4B"
TEST_MODELS = [ TEST_MODELS = [
QWEN, QWEN,
LLAMA, LLAMA,
QWEN_EMBEDDING,
] ]
# Env-driven defaults for specific test groups # Env-driven defaults for specific test groups
......
...@@ -4,7 +4,12 @@ ...@@ -4,7 +4,12 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from tests.utils.client import send_request from tests.utils.client import send_request
from tests.utils.payloads import ChatPayload, CompletionPayload, MetricsPayload from tests.utils.payloads import (
ChatPayload,
CompletionPayload,
EmbeddingPayload,
MetricsPayload,
)
# Common default text prompt used across tests # Common default text prompt used across tests
TEXT_PROMPT = "Tell me a short joke about AI." TEXT_PROMPT = "Tell me a short joke about AI."
...@@ -123,6 +128,47 @@ def completion_payload( ...@@ -123,6 +128,47 @@ def completion_payload(
) )
def embedding_payload_default(
    repeat_count: int = 3,
    expected_response: Optional[List[str]] = None,
    expected_log: Optional[List[str]] = None,
) -> EmbeddingPayload:
    """Build an embedding payload with a fixed pair of input sentences.

    Args:
        repeat_count: Passed through to ``EmbeddingPayload``.
        expected_response: Expected response strings; when falsy, defaults to
            a check for two generated embeddings.
        expected_log: Expected log strings; when falsy, an empty list is used.

    Returns:
        The constructed ``EmbeddingPayload``.
    """
    default_inputs = ["The sky is blue.", "Machine learning is fascinating."]
    log_checks = expected_log or []
    response_checks = expected_response or ["Generated 2 embeddings with dimension"]

    return EmbeddingPayload(
        body={"input": default_inputs},
        repeat_count=repeat_count,
        expected_log=log_checks,
        expected_response=response_checks,
    )
def embedding_payload(
    input_text: Union[str, List[str]],
    repeat_count: int = 3,
    expected_response: Optional[List[str]] = None,
    expected_log: Optional[List[str]] = None,
) -> EmbeddingPayload:
    """Build an embedding payload for one string or a list of strings.

    Args:
        input_text: A single string (wrapped into a one-element list) or a
            list of strings (used as given).
        repeat_count: Passed through to ``EmbeddingPayload``.
        expected_response: Expected response strings; when falsy, defaults to
            a check for as many embeddings as there are inputs.
        expected_log: Expected log strings; when falsy, an empty list is used.

    Returns:
        The constructed ``EmbeddingPayload``.
    """
    # A bare string becomes a one-element list so the body always holds a list.
    inputs = [input_text] if isinstance(input_text, str) else input_text
    count = len(inputs)

    log_checks = expected_log or []
    response_checks = expected_response or [
        f"Generated {count} embeddings with dimension"
    ]

    return EmbeddingPayload(
        body={"input": inputs},
        repeat_count=repeat_count,
        expected_log=log_checks,
        expected_response=response_checks,
    )
# Build small request-based health checks for chat and completions # Build small request-based health checks for chat and completions
# these should only be used as a last resort. Generally want to use an actual health check # these should only be used as a last resort. Generally want to use an actual health check
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment