Commit 15de1807 authored by Alec, committed by GitHub

feat: Add http support to vllm kv router example (#217)


Co-authored-by: Sean Choi <choishsean@gmail.com>
Co-authored-by: aflowers <aflowers@nvidia.com>
parent 0439d3b5
@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
bash /workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B

# List tmux sessions
tmux ls
```
@@ -203,15 +203,31 @@ source /opt/triton/venv/bin/activate
```bash
# Launch the router
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.router \
    --routing-strategy prefix \
    --model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --min-workers 1
```
You can choose only the prefix strategy for now:
- `prefix`: Route requests to the worker that has the longest prefix match.
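To make the `prefix` strategy concrete, here is a small illustrative sketch. The helper name and data shapes are invented for illustration only; the real router tracks KV blocks via events published by the workers:

```python
def longest_prefix_worker(request_blocks, worker_cache):
    """Pick the worker whose cached blocks share the longest prefix with the request.

    request_blocks: list of block hashes for the incoming prompt
    worker_cache:   dict mapping worker_id -> list of cached block hashes
    Returns "" when no worker holds any matching prefix (caller falls back to random).
    """
    best_worker, best_len = "", 0
    for worker_id, cached in worker_cache.items():
        match = 0
        for req, got in zip(request_blocks, cached):
            if req != got:
                break
            match += 1
        if match > best_len:
            best_worker, best_len = worker_id, match
    return best_worker

# Worker "a" holds the first two blocks of the request; "b" only matches later blocks,
# which do not count because KV reuse requires a contiguous prefix.
chosen = longest_prefix_worker(
    ["h1", "h2", "h3"],
    {"a": ["h1", "h2"], "b": ["hX", "h2", "h3"]},
)  # -> "a"
```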
**Terminal 2 - Processor:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
# Processor must take the same args as the worker
# This is temporary until we communicate the ModelDeploymentCard over etcd
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.processor \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384
```
**Terminal 3 and 4 - Workers:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
```
@@ -229,20 +245,45 @@ RUST_LOG=info python3 -m kv_router.worker \
Note: Prefix caching must be enabled for the KV Router to work.
Note: block-size must be 64, otherwise the Router won't work (it accepts only 64-token blocks).
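The 64-token requirement follows from how matching works: the router compares complete KV blocks, so tokens past the last full block never contribute to a prefix hit. A hypothetical illustration of the chunking (whether the real router hashes blocks exactly this way is an implementation detail):

```python
BLOCK_SIZE = 64  # the only block size the router currently accepts


def to_blocks(token_ids, block_size=BLOCK_SIZE):
    """Split prompt tokens into complete KV blocks; the trailing partial
    block is dropped because only full blocks can be matched against a cache."""
    n_full = len(token_ids) // block_size
    return [tuple(token_ids[i * block_size:(i + 1) * block_size]) for i in range(n_full)]


# A 150-token prompt yields 2 matchable blocks; the last 22 tokens never match.
blocks = to_blocks(list(range(150)))
```

This is also why the client example uses a long prompt: anything under 64 tokens produces zero full blocks and can never score a prefix hit.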
**Terminal 5 - Client:**
Don't forget to add the model to the server:
```bash
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init.process.chat/completions
```
```bash
curl localhost:9992/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "messages": [
        {
            "role": "user",
            "content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
        }
    ],
    "stream": false,
    "max_tokens": 30
}'
```
Expected output:
```json
{
"id": "f435d1aa-d423-40a0-a616-00bc428a3e32",
"choices": [
{
"message": {
"role": "assistant",
"content": "Alright, the user is playing a character in a D&D setting. They want a detailed background for their character, set in the world of Eldoria, particularly in the city of Aeloria. The user mentioned it's about an intrepid explorer"
},
"index": 0,
"finish_reason": "length"
}
],
"created": 1740020570,
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"object": "chat.completion",
"usage": null,
"system_fingerprint": null
}
```
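The same request can be issued from Python using only the standard library. The port and endpoint are taken from the curl example above; adjust `base_url` if your HTTP frontend runs elsewhere:

```python
import json
from urllib import request

MODEL = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"


def build_payload(prompt: str, max_tokens: int = 30) -> dict:
    """Build the same request body as the curl example."""
    return {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
        "max_tokens": max_tokens,
    }


def chat_completion(prompt: str, base_url: str = "http://localhost:9992") -> dict:
    """POST the payload to the HTTP frontend and decode the JSON response."""
    req = request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(build_payload(prompt)).encode(),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.loads(resp.read())
```

Usage mirrors the curl call: `chat_completion("In the heart of Eldoria, ...")["choices"][0]["message"]["content"]`.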
### 6. Known Issues and Limitations
@@ -35,18 +35,19 @@ class BaseVllmEngine:
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.engine_client = None
        self.chat_processor: ChatProcessor | None = None
        self._engine_context = None

    async def initialize(self):
        """Initialize the engine client and related components."""
        logger.info("Initializing engine client")
        self._engine_context = build_async_engine_client_from_engine_args(
            self.engine_args
        )
        if self._engine_context is not None:
            self.engine_client = await self._engine_context.__aenter__()
            self.tokenizer = await self.engine_client.get_tokenizer()
            self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
        else:
            raise RuntimeError("Failed to initialize engine client")
@@ -67,34 +68,6 @@ class BaseVllmEngine:
    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.cleanup()
    @abc.abstractmethod
    async def generate(self, raw_request):
        pass
@@ -14,20 +14,75 @@
# limitations under the License.
import json
from typing import AsyncIterator, List, Protocol, runtime_checkable

from vllm import TokensPrompt
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
    engine_args: AsyncEngineArgs
    chat_processor: "ChatProcessor | None"
    model_config: ModelConfig


class ProcessMixIn(ProcessMixInRequired):
    """
    Mixin for pre and post processing for vLLM
    Requires engine_args, engine_client, chat_processor, model_config to be initialized
    """

    engine_args: AsyncEngineArgs
    chat_processor: "ChatProcessor | None"
    model_config: ModelConfig

    def __init__(self):
        pass

    async def _parse_raw_request(self, raw_request):
        if self.chat_processor is None:
            raise RuntimeError("chat_processor has not been initialized")
        request = self.chat_processor.parse_raw_request(raw_request)
        (
            conversation,
            request_prompt,
            engine_prompt,
        ) = await self.chat_processor.preprocess(raw_request)
        default_max_tokens = self.model_config.max_model_len - len(
            engine_prompt["prompt_token_ids"]
        )
        default_sampling_params = self.model_config.get_diff_sampling_param()
        sampling_params = request.to_sampling_params(
            default_max_tokens,
            self.model_config.logits_processor_pattern,
            default_sampling_params,
        )
        return request, conversation, request_prompt, engine_prompt, sampling_params

    async def _stream_response(self, request, generator, request_id, conversation):
        if self.chat_processor is None:
            raise RuntimeError("chat_processor has not been initialized")
        return self.chat_processor.stream_response(
            request,
            generator,
            request_id,
            conversation,
        )
class ChatProcessor:
    def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.openai_serving = OpenAIServingChat(
            engine_client=None,
@@ -42,9 +97,10 @@ class ChatProcessor:
    def parse_raw_request(self, raw_request: dict) -> ChatCompletionRequest:
        return ChatCompletionRequest.parse_obj(raw_request)

    async def preprocess(
        self, raw_request: dict
    ) -> tuple[ConversationMessage, RequestPrompt, TokensPrompt]:
        request = self.parse_raw_request(raw_request)
        (
            conversation,
@@ -52,9 +108,9 @@ class ChatProcessor:
            engine_prompts,
        ) = await self.openai_serving._preprocess_chat(
            request,
            self.tokenizer,
            request.messages,
            chat_template=request.chat_template or self.tokenizer.chat_template,
            chat_template_content_format=self.openai_serving.chat_template_content_format,
            add_generation_prompt=request.add_generation_prompt,
            continue_final_message=request.continue_final_message,
@@ -75,7 +131,6 @@ class ChatProcessor:
        request_id: str,
        conversation: List,
    ):
        request_metadata = RequestResponseMetadata(request_id=request_id)
        assert request.stream, "Only stream is supported"
        async for raw_response in self.openai_serving.chat_completion_stream_generator(
@@ -84,7 +139,7 @@ class ChatProcessor:
            request_id,
            request.model,
            conversation,
            self.tokenizer,
            request_metadata,
        ):
            if raw_response.startswith("data: [DONE]"):
@@ -14,7 +14,15 @@
# limitations under the License.
import json
from typing import Any, List, Optional

import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm import CompletionOutput, SamplingParams, TokensPrompt
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
@@ -26,10 +34,6 @@ class Tokens(BaseModel):
    tokens: list[int]


class PrefillRequest(Request):
    request_id: str
@@ -40,3 +44,70 @@ class Response(BaseModel):
class PrefillResponse(BaseModel):
    prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
    multi_modal_data: NotRequired[Optional[Any]]  # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# SamplingParams is a msgspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
    lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
    """
    Serializable class of all the fields vLLM engine requires for inference
    """

    # Merged into a single ConfigDict: a second `model_config` assignment would
    # silently override the first and drop arbitrary_types_allowed
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)},
    )

    engine_prompt: PatchedTokensPrompt
    sampling_params: SamplingParams
    request_id: str

    @field_validator("sampling_params", mode="before")
    @classmethod
    def parse_sampling_params(cls, v: Any) -> SamplingParams:
        if isinstance(v, str):
            v = json.loads(v)
        if isinstance(v, dict):
            return SamplingParams(**v)
        return v
class MyRequestOutput(BaseModel):
    """
    RequestOutput from vLLM is not serializable by default
    https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
    This class is used to serialize the RequestOutput and any recursively defined types
    We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    request_id: str
    prompt: Optional[str] = None
    prompt_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[PromptLogprobs] = None
    outputs: List[CompletionOutput]
    finished: bool
    metrics: Optional[RequestMetrics] = None
    # lora_request: Optional[LoRARequest] = None
    # encoder_prompt: Optional[str] = None
    # encoder_prompt_token_ids: Optional[List[int]] = None
    # num_cached_tokens: Optional[int] = None
    # multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
@@ -21,6 +21,7 @@ import uuid
import msgspec
import uvloop
from common.base_engine import BaseVllmEngine
from common.chat_processor import ProcessMixIn
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
@@ -32,7 +33,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import logger as vllm_logger


class VllmDecodeEngine(BaseVllmEngine, ProcessMixIn):
    """
    Request handler for the generate endpoint
    """
@@ -15,34 +15,74 @@
#!/bin/bash
# LIMITATIONS:
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single node
if [ $# -lt 2 ]; then
    echo "Usage: $0 <number_of_workers> <routing_strategy> [model_name] [endpoint_name]"
    echo "Error: Must specify at least number of workers and routing strategy"
    echo "Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
    echo "Optional: endpoint_name (default: triton-init.process.chat/completions)"
    exit 1
fi
NUM_WORKERS=$1
ROUTING_STRATEGY=$2
MODEL_NAME=${3:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
ENDPOINT_NAME=${4:-"triton-init.process.chat/completions"}
VALID_STRATEGIES=("prefix")

SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/triton/venv/bin/activate && cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
    echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
    exit 1
fi
########################################################
# HTTP Server
########################################################
HTTP_CMD="TRD_LOG=DEBUG http"
tmux new-session -d -s "$SESSION_NAME-http"
tmux send-keys -t "$SESSION_NAME-http" "$INIT_CMD && $HTTP_CMD" C-m
########################################################
# LLMCTL
########################################################
LLMCTL_CMD="sleep 5 && llmctl http remove chat-model $MODEL_NAME && \
    llmctl http add chat-model $MODEL_NAME $ENDPOINT_NAME && \
    llmctl http list chat-model"

tmux new-session -d -s "$SESSION_NAME-llmctl"
tmux send-keys -t "$SESSION_NAME-llmctl" "$INIT_CMD && $LLMCTL_CMD" C-m
########################################################
# Processor
########################################################
# For now processor gets same args as worker, need to have them communicate over etcd
PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384 "
tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
########################################################
# Router
########################################################
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
    --model $MODEL_NAME \
    --routing-strategy $ROUTING_STRATEGY \
    --min-workers $NUM_WORKERS "

tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
########################################################
# Workers
########################################################
WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker \
    --model $MODEL_NAME \
    --tokenizer $MODEL_NAME \
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from typing import AsyncIterator
import uvloop
from common.chat_processor import ChatProcessor, ProcessMixIn
from common.parser import parse_vllm_args
from common.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from transformers import AutoTokenizer
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from triton_distributed_rs._core import Client
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
class Processor(ProcessMixIn):
    """
    vLLM pre and post processing
    """

    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        router_client: Client,
        workers_client: Client,
    ):
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.tokenizer = self._create_tokenizer(engine_args)
        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
        self.router_client = router_client
        self.workers_client = workers_client

    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
        """Create a tokenizer from the engine arguments, mirroring vLLM's approach"""
        model_path = engine_args.model

        # Create the base tokenizer with vLLM's typical settings
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left",
            use_fast=True,  # vLLM prefers the fast tokenizer for efficiency
        )
        return base_tokenizer
    async def generate_responses(
        self, engine_generator
    ) -> AsyncIterator[RequestOutput]:
        async for resp in engine_generator:
            # Deserialize the response from the engine
            # Creates correct vLLM objects for each field
            output = MyRequestOutput.model_validate_json(resp.data())

            # OpenAIServingChat.chat_completion_stream_generator() expects a RequestOutput object
            yield RequestOutput(
                request_id=output.request_id,
                prompt=output.prompt,
                prompt_token_ids=output.prompt_token_ids,
                prompt_logprobs=output.prompt_logprobs,
                outputs=output.outputs,
                finished=output.finished,
                metrics=output.metrics,
            )
    @triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
    async def generate(self, raw_request):
        request_id = str(uuid.uuid4())
        vllm_logger.debug(f"Got raw request: {raw_request}")
        (
            request,
            conversation,
            prompt,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)

        worker_id_generator: AsyncIterator = await self.router_client.generate(
            Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
        )
        worker_id = (
            await worker_id_generator.__anext__()
        )  # only one worker id is returned
        worker_id = worker_id.data()
        vllm_logger.info(f"Worker ID: {worker_id}")

        if worker_id == "":
            engine_generator = await self.workers_client.random(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
                    sampling_params=sampling_params,
                    request_id=request_id,
                ).model_dump_json()
            )
        else:
            engine_generator = await self.workers_client.direct(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
                    sampling_params=sampling_params,
                    request_id=request_id,
                ).model_dump_json(),
                uuid.UUID(worker_id).int,
            )

        output = self.generate_responses(engine_generator)

        async for response in await self._stream_response(
            request, output, request_id, conversation
        ):
            yield response
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Set up clients to the router and workers.
    Serve the triton-init.process.chat/completions endpoint.
    """
    workers_client = (
        await runtime.namespace("triton-init")
        .component("vllm")
        .endpoint("generate")
        .client()
    )
    router_client = (
        await runtime.namespace("triton-init")
        .component("router")
        .endpoint("generate")
        .client()
    )

    preprocess_component = runtime.namespace("triton-init").component("process")
    await preprocess_component.create_service()
    preprocess_endpoint = preprocess_component.endpoint("chat/completions")

    processor = Processor(engine_args, router_client, workers_client)
    assert isinstance(processor, ProcessMixIn)
    await preprocess_endpoint.serve_endpoint(processor.generate)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()
    asyncio.run(worker(engine_args))
@@ -15,12 +15,12 @@
import asyncio
from argparse import Namespace
from enum import Enum
from typing import AsyncIterator

import uvloop
from common.protocol import Tokens
from triton_distributed_rs import (
    DistributedRuntime,
    KvRouter,
@@ -29,6 +29,8 @@ from triton_distributed_rs import (
)
from vllm.logger import logger as vllm_logger

WorkerId = str


class RoutingStrategy(Enum):
    PREFIX = "prefix"
@@ -43,19 +45,17 @@ class Router:
    def __init__(
        self,
        router: KvRouter,
        routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
    ):
        vllm_logger.info(
            f"Initializing KV Router with strategy: {routing_strategy.value}"
        )
        self.router = router
        self.routing_strategy = routing_strategy

    @triton_endpoint(Tokens, WorkerId)
    async def generate(self, request) -> AsyncIterator[WorkerId]:
        lora_id = 0
        worker_id = ""
        if self.routing_strategy == RoutingStrategy.PREFIX:
@@ -70,35 +70,36 @@ class Router:
            vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
            yield worker_id
        else:
            # TODO: Do we implement round_robin and random here?
            # or just skip this router and directly enable in preprocess?
            raise NotImplementedError(
                f"Routing strategy {self.routing_strategy} not implemented"
            )
@triton_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
    """
    Set up the worker clients.
    Serve the triton-init.router.generate endpoint.
    """
    workers_client = (
        await runtime.namespace("triton-init")
        .component("vllm")
        .endpoint("generate")
        .client()
    )

    wait_task = workers_client.wait_for_endpoints()
    await asyncio.sleep(1)
    while not wait_task.done():
        vllm_logger.info("Waiting for workers to be ready...")
        await asyncio.sleep(5)
    wait_task.result()
    while len(workers_client.endpoint_ids()) < args.min_workers:
        vllm_logger.info(
@@ -112,22 +113,16 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
        )
    # TODO Router is a fixed namespace separate from the others
    kv_listener = runtime.namespace("router").component(args.model_name)
    await kv_listener.create_service()

    router_component = runtime.namespace("triton-init").component("router")
    await router_component.create_service()

    router = KvRouter(runtime, kv_listener)
    endpoint = router_component.endpoint("generate")
    await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -149,6 +144,12 @@ if __name__ == "__main__": ...@@ -149,6 +144,12 @@ if __name__ == "__main__":
default=1, default=1,
help="Minimum number of workers required before proceeding", help="Minimum number of workers required before proceeding",
) )
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(worker(args)) asyncio.run(worker(args))
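The router polls a readiness task instead of blocking on it, so it can keep logging progress while workers come up. A minimal self-contained sketch of that pattern, where `slow_startup()` is a hypothetical stand-in for `workers_client.wait_for_endpoints()`:

```python
import asyncio

# Hypothetical stand-in for wait_for_endpoints(): resolves once workers
# have registered their endpoints.
async def slow_startup():
    await asyncio.sleep(0.3)
    return "ready"

async def main():
    wait_task = asyncio.create_task(slow_startup())
    # Poll instead of awaiting directly, so we can log while waiting
    while not wait_task.done():
        await asyncio.sleep(0.1)
    # result() re-raises any startup exception instead of swallowing it
    return wait_task.result()

result = asyncio.run(main())
assert result == "ready"
```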
@@ -16,111 +16,75 @@
import asyncio
import os
import uuid
from typing import AsyncIterator

import uvloop
from common.base_engine import BaseVllmEngine
from common.parser import parse_vllm_args
from common.protocol import MyRequestOutput, vLLMGenerateRequest
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.logger import logger as vllm_logger
from vllm.sampling_params import RequestOutputKind

vllm_logger.info(f"VLLM_KV_CAPI_PATH: {os.environ['VLLM_KV_CAPI_PATH']}")


class VllmEngine(BaseVllmEngine):
    """
    vLLM Inference Engine
    """

    def __init__(self, engine_args: AsyncEngineArgs):
        self.engine_args = engine_args
        super().__init__(engine_args)

    @triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
    async def generate(self, request) -> AsyncIterator:
        if self.engine_client is None:
            await self.initialize()
        assert self.engine_client is not None, "engine_client was not initialized"

        sampling_params = request.sampling_params
        # rust HTTP requires Delta streaming
        sampling_params.output_kind = RequestOutputKind.DELTA
        async for response in self.engine_client.generate(
            request.engine_prompt, sampling_params, request.request_id
        ):
            # MyRequestOutput takes care of serializing the response as
            # vLLM's RequestOutput is not serializable by default
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()
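With `RequestOutputKind.DELTA`, each streamed message carries only the text generated since the previous one, serialized to JSON because the engine's native output object is not serializable. A stand-alone sketch of that streaming shape (the field names here are illustrative, not the real `MyRequestOutput` schema):

```python
import asyncio
import json

# Hypothetical stand-in for a streaming generate(): every message is a
# JSON string carrying only the newly produced text (delta semantics).
async def generate_deltas(chunks):
    for chunk in chunks:
        yield json.dumps({"delta": chunk, "finished": False})
    yield json.dumps({"delta": "", "finished": True})

async def collect():
    messages = []
    async for msg in generate_deltas(["Hel", "lo"]):
        messages.append(json.loads(msg))
    return messages

messages = asyncio.run(collect())
# Concatenating the deltas reassembles the full completion
assert "".join(m["delta"] for m in messages) == "Hello"
assert messages[-1]["finished"] is True
```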
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Serve the triton-init.vllm.generate endpoint.
    """
    worker_component = runtime.namespace("triton-init").component("vllm")
    await worker_component.create_service()

    worker_endpoint = worker_component.endpoint("generate")

    # KV Publisher and Aggregator requires a UUID (str)
    # KV Router requires a lease_id (int)
    # This allows us to please both, until they are unified
    # If VLLM_WORKER_ID is not set, KV Routing will fail
    VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
    os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
    vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")

    vllm_engine = VllmEngine(engine_args)

    await worker_endpoint.serve_endpoint(vllm_engine.generate)


if __name__ == "__main__":
...
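The `uuid.UUID(int=...)` conversion gives a deterministic, lossless mapping between the integer lease id and the UUID string, so the same worker identity satisfies both the KV Publisher (which wants a string) and the KV Router (which wants the integer). A small sketch, using a hypothetical `lease_id` in place of the real `worker_endpoint.lease_id()`:

```python
import uuid

# Hypothetical integer lease id; the real value comes from the runtime
# via worker_endpoint.lease_id().
lease_id = 12345
worker_id = uuid.UUID(int=lease_id)

# The mapping is lossless: the integer lease id is recoverable from the UUID
assert worker_id.int == lease_id
# 12345 == 0x3039, zero-padded into the 128-bit UUID representation
assert str(worker_id) == "00000000-0000-0000-0000-000000003039"
```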
@@ -19,6 +19,7 @@ import uuid
import uvloop
from common.base_engine import BaseVllmEngine
from common.chat_processor import ProcessMixIn
from common.parser import parse_vllm_args
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -29,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import logger as vllm_logger


class VllmEngine(BaseVllmEngine, ProcessMixIn):
    """
    Request handler for the generate endpoint
    """
...
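Subclassing both a base engine and a processing mix-in lets the processor reuse generation logic while layering request handling on top. A hypothetical sketch of that composition; the class and method names below are illustrative, not the real `BaseVllmEngine`/`ProcessMixIn` API:

```python
# Base class provides generation; the mix-in provides preprocessing;
# the concrete engine inherits both via multiple inheritance.
class BaseEngine:
    def generate(self, prompt: str) -> str:
        return f"generated<{prompt}>"

class ProcessMixIn:
    def preprocess(self, prompt: str) -> str:
        # e.g. chat-template application in the real processor
        return prompt.strip()

class ChatEngine(BaseEngine, ProcessMixIn):
    def run(self, prompt: str) -> str:
        return self.generate(self.preprocess(prompt))

assert ChatEngine().run("  hello  ") == "generated<hello>"
```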