Commit e3f14051 authored by Maksim Khadkevich, committed by GitHub

feat: added back smart routing and basic vllm examples (#111)

parent d57847b2
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from common.chat_processor import ChatProcessor
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
logger = logging.getLogger("vllm")
class BaseVllmEngine(abc.ABC):
    """
    Base class that manages the lifecycle of a vLLM async engine client.
    Subclasses implement the abstract `generate` request handler.
    """
def __init__(self, engine_args: AsyncEngineArgs):
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.engine_client = None
        self.chat_processor: ChatProcessor | None = None
        self.tokenizer = None
self._engine_context = None
async def initialize(self):
"""Initialize the engine client and related components."""
logger.info("Initializing engine client")
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
self.tokenizer = await self.engine_client.get_tokenizer()
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
else:
raise RuntimeError("Failed to initialize engine client")
async def cleanup(self):
"""Cleanup resources."""
print("Cleaning up engine client")
if self._engine_context is not None:
await self._engine_context.__aexit__(None, None, None)
self._engine_context = None
self.engine_client = None
self.chat_processor = None
    async def __aenter__(self):
        """Initialize with context manager syntax."""
        await self.initialize()
        return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
@abc.abstractmethod
async def generate(self, raw_request):
pass
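# A minimal usage sketch (illustrative, not part of this module): subclass
# BaseVllmEngine, implement `generate`, and drive the lifecycle with
# `async with` so initialize() and cleanup() are always paired.
#
# class EchoEngine(BaseVllmEngine):
#     async def generate(self, raw_request):
#         yield raw_request
#
# async def main():
#     args = AsyncEngineArgs(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
#     async with EchoEngine(args) as engine:
#         async for out in engine.generate("hello"):
#             print(out)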
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
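# A minimal host-class sketch (illustrative): ProcessMixIn assumes these
# attributes exist before _parse_raw_request is called. The Processor service
# in the KV-router example below is the concrete version of this wiring.
#
# class MyHandler(ProcessMixIn):
#     def __init__(self):
#         self.engine_args = AsyncEngineArgs(model="...")
#         self.model_config = self.engine_args.create_model_config()
#         self.chat_processor = ChatProcessor(tokenizer, self.model_config)
#         self.completions_processor = CompletionsProcessor(tokenizer, self.model_config)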
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
            response = json.loads(raw_response.removeprefix("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
            response = json.loads(raw_response.removeprefix("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
from .protocol import Request
@dynamo_worker()
async def worker(
runtime: DistributedRuntime,
component: str,
prompt: str,
max_tokens: int,
temperature: float,
):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace("dynamo").component(component).endpoint("generate")
# create client
client = await endpoint.client()
    # issue request (increase the range below to send concurrent requests)
tasks = []
for _ in range(1):
tasks.append(
client.generate(
Request(
prompt=prompt,
sampling_params={
"temperature": temperature,
"max_tokens": max_tokens,
},
).model_dump_json()
)
)
streams = await asyncio.gather(*tasks)
# process response
for stream in streams:
async for resp in stream:
print(resp)
if __name__ == "__main__":
uvloop.install()
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, default="what is the capital of france?")
parser.add_argument("--component", type=str, default="vllm")
parser.add_argument("--max-tokens", type=int, default=10)
parser.add_argument("--temperature", type=float, default=0.5)
args = parser.parse_args()
asyncio.run(worker(args.component, args.prompt, args.max_tokens, args.temperature))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
return AsyncEngineArgs.from_cli_args(args)
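# Example invocation (hypothetical script name; the flags are standard vLLM
# engine arguments):
#   python worker.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
#       --max-model-len 16384 --enable-prefix-caching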
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# SamplingParams is a msgspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
    )
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
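if __name__ == "__main__":
    # Round-trip sketch (illustrative): show that a request output survives
    # JSON serialization through MyRequestOutput. The CompletionOutput fields
    # used here match recent vLLM releases but should be treated as assumptions.
    demo = MyRequestOutput(
        request_id="demo-0",
        prompt="hello",
        prompt_token_ids=[1, 2, 3],
        outputs=[
            CompletionOutput(
                index=0,
                text=" world",
                token_ids=[4, 5],
                cumulative_logprob=None,
                logprobs=None,
            )
        ],
        finished=True,
    )
    wire = demo.model_dump_json()
    assert MyRequestOutput.model_validate_json(wire).finished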
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## Overview
Pipeline Architecture:
```
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
```
## Unified serve
1. Launch all three services with a single command:
```bash
cd /workspace/examples/python_rs/llm/vllm
dynamo-sdk serve sdk_basic_service.basic:Frontend
```
2. Send a request to the frontend using curl:
```bash
curl -X 'POST' \
'http://localhost:3000/generate' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"text": "test"
}'
```
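For programmatic access, the same stream can be consumed from Python. Below is a minimal sketch assuming `httpx` is installed (any streaming-capable HTTP client works):
```python
import asyncio

import httpx


async def main():
    async with httpx.AsyncClient() as client:
        # Stream server-sent events from the Frontend's /generate endpoint
        async with client.stream(
            "POST",
            "http://localhost:3000/generate",
            json={"text": "test"},
            headers={"accept": "text/event-stream"},
        ) as response:
            async for line in response.aiter_lines():
                if line:
                    print(line)


asyncio.run(main())
```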
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
"""
class RequestType(BaseModel):
text: str
class ResponseType(BaseModel):
text: str
@service(
resources={"cpu": "2"},
traffic={"timeout": 30},
dynamo={
"enabled": True,
"namespace": "inference",
},
workers=3,
)
class Backend:
def __init__(self) -> None:
print("Starting backend")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Generate tokens."""
req_text = req.text
print(f"Backend received: {req_text}")
text = f"{req_text}-back"
for token in text.split():
yield f"Backend: {token}"
@service(
resources={"cpu": "2"},
traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"},
)
class Middle:
backend = depends(Backend)
def __init__(self) -> None:
print("Starting middle")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Forward requests to backend."""
req_text = req.text
print(f"Middle received: {req_text}")
text = f"{req_text}-mid"
next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request):
print(f"Middle received response: {response}")
yield f"Middle: {response}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60}) # Regular HTTP API
class Frontend:
middle = depends(Middle)
def __init__(self) -> None:
print("Starting frontend")
@api
async def generate(self, text):
"""Stream results from the pipeline."""
print(f"Frontend received: {text}")
print(f"Frontend received type: {type(text)}")
txt = RequestType(text=text)
print(f"Frontend sending: {type(txt)}")
async for response in self.middle.generate(txt.model_dump_json()):
yield f"Frontend: {response}"
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
Run this example using the command below:
```bash
cd /workspace/examples/python_rs/llm/vllm
dynamo-sdk serve sdk_kv_router.frontend:Frontend
```
Send a request to the HTTP service:
```bash
curl -X 'POST' \
'http://localhost:3000/chat_completion' \
-H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \
-d '{
"msg": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}'
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sdk_kv_router.processor import Processor
from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
@service(traffic={"timeout": 10000}, image=DYNAMO_IMAGE)
class Frontend:
processor = depends(Processor)
def __init__(self):
print("frontend init")
@api
async def chat_completion(self, msg: str):
# Call the generate method
generator = self.processor.generate(
{
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"messages": [{"role": "user", "content": msg}],
"stream": True,
"max_tokens": 10,
}
)
# Now iterate over the async generator
async for response in generator:
print("client response_data:", response)
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from typing import AsyncIterator
import bentoml
from sdk_kv_router.router import Router
from sdk_kv_router.worker import VllmEngine
with bentoml.importing():
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from common.chat_processor import ChatProcessor, ProcessMixIn
from common.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from dynamo.sdk import depends, dynamo_context, dynamo_endpoint, service
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
workers = depends(VllmEngine)
router = depends(Router)
def __init__(self):
model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.engine_args = AsyncEngineArgs(
model=model,
tokenizer=model,
enable_prefix_caching=True,
block_size=64,
max_model_len=16384,
)
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer()
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
def _create_tokenizer(self) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = self.engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
            use_fast=True,  # vLLM typically uses the fast tokenizer for efficiency
)
return base_tokenizer
async def generate_responses(
self, engine_generator
) -> AsyncIterator[RequestOutput]:
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
yield RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
@dynamo_endpoint()
async def generate(self, raw_request: ChatCompletionRequest):
request_id = str(uuid.uuid4())
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_id = None
async for worker in self.router.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
):
worker_id = worker
break
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmEngine.dynamo_address() # type: ignore
worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
if worker_id == "":
engine_generator = await worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
else:
engine_generator = await worker_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
uuid.UUID(worker_id).int,
)
output = self.generate_responses(engine_generator)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from enum import Enum
from typing import Optional
import bentoml
from common.protocol import Tokens
from dynamo.sdk import async_onstart, dynamo_context, dynamo_endpoint, service
with bentoml.importing():
from dynamo.runtime import KvRouter
WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
ROUND_ROBIN = "round_robin"
RANDOM = "random"
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(self):
self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.routing_strategy = RoutingStrategy.PREFIX
self.runtime = dynamo_context["runtime"]
self.min_workers = 1
self.kv_block_size = 64
        self.router: Optional[KvRouter] = None
@async_onstart
async def init_engine(self):
workers_client = (
await self.runtime.namespace("dynamo")
.component("VllmEngine")
.endpoint("generate")
.client()
)
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
print("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < self.min_workers:
print(
f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {self.min_workers}"
)
await asyncio.sleep(5)
kv_listener = self.runtime.namespace("dynamo").component(self.model_name)
await kv_listener.create_service()
self.router = KvRouter(self.runtime, kv_listener, self.kv_block_size)
@dynamo_endpoint()
async def generate(self, request: Tokens):
lora_id = 0
worker_id = ""
if self.routing_strategy == RoutingStrategy.PREFIX:
try:
worker_id = await self.router.schedule(request.tokens, lora_id)
except Exception as e:
if "No worker found" in str(e):
worker_id = ""
else:
print(f"Error during worker selection: {e}")
print(f"Scheduling to worker_id: {worker_id}")
yield worker_id
else:
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from typing import Optional
import bentoml
with bentoml.importing():
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind
from common.base_engine import BaseVllmEngine
from common.protocol import MyRequestOutput, vLLMGenerateRequest
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import (
async_onstart,
dynamo_context,
dynamo_endpoint,
server_context,
service,
)
lease_id = None
## TODO: metrics_publisher.create_endpoint(worker_component),
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmEngine(BaseVllmEngine):
"""
vLLM Inference Engine
"""
def __init__(self):
model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
self.engine_args = AsyncEngineArgs(
model=model,
gpu_memory_utilization=0.8,
enable_prefix_caching=True,
block_size=64,
max_model_len=16384,
)
VLLM_WORKER_ID = dynamo_context["endpoints"][0].lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
os.environ["VLLM_KV_NAMESPACE"] = "dynamo"
os.environ["VLLM_KV_COMPONENT"] = "vllm"
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
os.environ["CUDA_VISIBLE_DEVICES"] = f"{server_context.worker_index - 1}"
self.metrics_publisher = KvMetricsPublisher()
self.engine_client: Optional[MQLLMEngineClient] = None
super().__init__(self.engine_args)
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
await self.metrics_publisher.create_service(component)
@async_onstart
async def init_engine(self):
if self.engine_client is None:
await super().initialize()
print("vLLM worker initialized")
assert self.engine_client is not None, "engine_client was not initialized"
self.engine_client.set_metrics_publisher(self.metrics_publisher)
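        # Publish an initial, placeholder metrics snapshot (zero load, nominal
        # capacity) so the KV router can discover this worker before real
        # engine statistics start flowing.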
self.metrics_publisher.publish(
0,
1024,
0,
1024,
0,
0,
0,
)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(lambda _: print("metrics publisher endpoint created"))
@dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest):
sampling_params = request.sampling_params
# rust HTTP requires Delta streaming
sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate( # type: ignore
request.engine_prompt, sampling_params, request.request_id
):
# MyRequestOutput takes care of serializing the response as
# vLLM's RequestOutput is not serializable by default
resp = MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
yield resp