Commit 15de1807 authored by Alec, committed by GitHub

feat: Add http support to vllm kv router example (#217)


Co-authored-by: Sean Choi <choishsean@gmail.com>
Co-authored-by: aflowers <aflowers@nvidia.com>
parent 0439d3b5
......@@ -180,7 +180,7 @@ kv-router-run.sh <number_of_workers> <routing_strategy> Optional[<model_name>]
Example:
```bash
# Launch 8 workers with prefix routing strategy and use deepseek-ai/DeepSeek-R1-Distill-Llama-8B as the model
/workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
bash /workspace/examples/python_rs/llm/vllm/kv-router-run.sh 8 prefix deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# List tmux sessions
tmux ls
......@@ -203,15 +203,31 @@ source /opt/triton/venv/bin/activate
# Launch router
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.router \
--routing-strategy prefix
--routing-strategy prefix \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--min-workers 1
```
You can choose between different routing strategies:
Only the `prefix` strategy is supported for now:
- `prefix`: Route requests to the worker that has the longest prefix match (a minimal sketch of the idea follows the list below).
- `round_robin`: Route requests to the worker in a round-robin manner.
- `random`: Route requests to a random worker.
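The sketch below is an illustrative Python approximation of prefix routing, not the actual `KvRouter` implementation (which tracks KV events published by the workers); the block hashing and per-worker cache bookkeeping shown here are assumptions made for the example.

```python
# Illustrative sketch of longest-prefix routing; NOT the real KvRouter.
# Assumption: each worker reports the hashes of the 64-token blocks it has cached.

BLOCK_SIZE = 64


def block_hashes(tokens: list[int]) -> list[int]:
    """Hash the prompt in consecutive full 64-token blocks."""
    return [
        hash(tuple(tokens[i : i + BLOCK_SIZE]))
        for i in range(0, len(tokens) - BLOCK_SIZE + 1, BLOCK_SIZE)
    ]


def pick_worker(tokens: list[int], cached_blocks: dict[str, set[int]]) -> str:
    """Return the worker whose cache covers the longest prefix of the prompt."""
    best_worker, best_match = "", 0
    for worker_id, blocks in cached_blocks.items():
        match = 0
        for block_hash in block_hashes(tokens):
            if block_hash not in blocks:
                break
            match += 1
        if match > best_match:
            best_worker, best_match = worker_id, match
    return best_worker  # "" means no overlap; the caller falls back to random


# worker-a has the first two blocks of this prompt cached, worker-b has none.
prompt_tokens = list(range(200))
caches = {"worker-a": set(block_hashes(prompt_tokens)[:2]), "worker-b": set()}
print(pick_worker(prompt_tokens, caches))  # -> worker-a
```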
**Terminal 2 and 3 - Workers:**
**Terminal 2 - Processor:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
# Processor must take the same args as the worker
# This is temporary until we communicate the ModelDeploymentCard over etcd
cd /workspace/examples/python_rs/llm/vllm
RUST_LOG=info python3 -m kv_router.processor \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384
```
**Terminal 3 and 4 - Workers:**
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
......@@ -229,20 +245,45 @@ RUST_LOG=info python3 -m kv_router.worker \
Note: Prefix caching must be enabled for the KV Router to work.
Note: `block-size` must be 64, otherwise the Router won't work (it only handles 64-token blocks).
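Since routing matches prompts at 64-token block granularity (an assumption consistent with the notes above), only prompts that span at least one full block give the router anything to match on; shorter prompts will typically fall back to random placement. A quick illustrative check:

```python
# Rough check of how many full 64-token KV blocks a prompt exposes for matching.
def full_blocks(num_tokens: int, block_size: int = 64) -> int:
    return num_tokens // block_size


print(full_blocks(600))  # 9 blocks -> good candidate for prefix routing
print(full_blocks(50))   # 0 blocks -> nothing to match, likely routed randomly
```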
**Terminal 3 - Client:**
**Terminal 5 - Client:**
Don't forget to add the model to the server:
```bash
# Activate virtual environment
source /opt/triton/venv/bin/activate
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B triton-init.process.chat/completions
```
# Run client
# We use a long prompt to populate a few KV Blocks (64 tokens each)
# Try running it a few times to see where the router is sending the request
cd /workspace/examples/python_rs/llm/vllm
python3 -m common.client \
--prompt "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden." \
--component preprocess \
--max-tokens 10 \
--temperature 0.5
```bash
curl localhost:9992/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
],
"stream":false,
"max_tokens": 30
}'
```
Expected output:
```json
{
"id": "f435d1aa-d423-40a0-a616-00bc428a3e32",
"choices": [
{
"message": {
"role": "assistant",
"content": "Alright, the user is playing a character in a D&D setting. They want a detailed background for their character, set in the world of Eldoria, particularly in the city of Aeloria. The user mentioned it's about an intrepid explorer"
},
"index": 0,
"finish_reason": "length"
}
],
"created": 1740020570,
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"object": "chat.completion",
"usage": null,
"system_fingerprint": null
}
```
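The same request can also be issued from Python; below is a minimal sketch using the `requests` library (assuming the HTTP frontend listens on `localhost:9992` as in the curl example above):

```python
# Minimal Python alternative to the curl example above (illustrative).
import requests

payload = {
    "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "messages": [{"role": "user", "content": "In the heart of Eldoria, ..."}],
    "stream": False,
    "max_tokens": 30,
}

resp = requests.post(
    "http://localhost:9992/v1/chat/completions", json=payload, timeout=120
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```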
### 6. Known Issues and Limitations
......
......@@ -35,18 +35,19 @@ class BaseVllmEngine:
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.engine_client = None
self.chat_processor = None
self.chat_processor: ChatProcessor | None = None
self._engine_context = None
async def initialize(self):
"""Initialize the engine client and related components."""
print("Initializing engine client")
logger.info("Initializing engine client")
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
self.chat_processor = ChatProcessor(self.engine_client, self.model_config)
self.tokenizer = await self.engine_client.get_tokenizer()
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
else:
raise RuntimeError("Failed to initialize engine client")
......@@ -67,34 +68,6 @@ class BaseVllmEngine:
async def __aexit__(self, exc_type, exc_value, traceback):
await self.cleanup()
async def _parse_raw_request(self, raw_request):
assert self.engine_client is not None
request = self.chat_processor.parse_raw_request(raw_request)
(
conversation,
request_prompt,
engine_prompt,
) = await self.chat_processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return request, conversation, request_prompt, engine_prompt, sampling_params
async def _stream_response(self, request, generator, request_id, conversation):
assert self.engine_client is not None
return self.chat_processor.stream_response(
request,
generator,
request_id,
conversation,
)
@abc.abstractmethod
async def generate(self, raw_request):
pass
......@@ -14,20 +14,75 @@
# limitations under the License.
import json
from typing import AsyncIterator, List
from typing import AsyncIterator, List, Protocol, runtime_checkable
import vllm
from vllm import TokensPrompt
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, chat_processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
async def _parse_raw_request(self, raw_request):
if self.chat_processor is None:
raise RuntimeError("chat_processor has not been initialized")
request = self.chat_processor.parse_raw_request(raw_request)
(
conversation,
request_prompt,
engine_prompt,
) = await self.chat_processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return request, conversation, request_prompt, engine_prompt, sampling_params
async def _stream_response(self, request, generator, request_id, conversation):
if self.chat_processor is None:
raise RuntimeError("chat_processor has not been initialized")
return self.chat_processor.stream_response(
request,
generator,
request_id,
conversation,
)
class ChatProcessor:
def __init__(self, engine_client: vllm.AsyncLLMEngine, model_config: ModelConfig):
self.engine_client = engine_client
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
......@@ -42,9 +97,10 @@ class ChatProcessor:
def parse_raw_request(self, raw_request: dict) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: dict):
async def preprocess(
self, raw_request: dict
) -> tuple[ConversationMessage, RequestPrompt, TokensPrompt]:
request = self.parse_raw_request(raw_request)
tokenizer = await self.engine_client.get_tokenizer()
(
conversation,
......@@ -52,9 +108,9 @@ class ChatProcessor:
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
tokenizer,
self.tokenizer,
request.messages,
chat_template=request.chat_template or tokenizer.chat_template,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
......@@ -75,7 +131,6 @@ class ChatProcessor:
request_id: str,
conversation: List,
):
tokenizer = await self.engine_client.get_tokenizer()
request_metadata = RequestResponseMetadata(request_id=request_id)
assert request.stream, "Only stream is supported"
async for raw_response in self.openai_serving.chat_completion_stream_generator(
......@@ -84,7 +139,7 @@ class ChatProcessor:
request_id,
request.model,
conversation,
tokenizer,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
......
......@@ -14,7 +14,15 @@
# limitations under the License.
from pydantic import BaseModel
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm import CompletionOutput, SamplingParams, TokensPrompt
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
......@@ -26,10 +34,6 @@ class Tokens(BaseModel):
tokens: list[int]
class TokenizedRequest(Request, Tokens):
pass
class PrefillRequest(Request):
request_id: str
......@@ -40,3 +44,70 @@ class Response(BaseModel):
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# SamplingParams is a msgspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)})
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
......@@ -21,6 +21,7 @@ import uuid
import msgspec
import uvloop
from common.base_engine import BaseVllmEngine
from common.chat_processor import ProcessMixIn
from common.parser import parse_vllm_args
from common.protocol import PrefillRequest
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
......@@ -32,7 +33,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import logger as vllm_logger
class VllmDecodeEngine(BaseVllmEngine):
class VllmDecodeEngine(BaseVllmEngine, ProcessMixIn):
"""
Request handler for the generate endpoint
"""
......
......@@ -15,34 +15,74 @@
#!/bin/bash
# LIMITATIONS:
# - Must use a single GPU for workers as CUDA_VISIBLE_DEVICES is set to a fixed value
# - Must use a single node
if [ $# -lt 2 ]; then
echo "Usage: $0 <number_of_workers> <routing_strategy> [model_name]"
echo "Usage: $0 <number_of_workers> <routing_strategy> [model_name] [endpoint_name]"
echo "Error: Must specify at least number of workers and routing strategy"
echo "Optional: model_name (default: deepseek-ai/DeepSeek-R1-Distill-Llama-8B)"
echo "Optional: endpoint_name (default: triton-init.process.chat/completions)"
exit 1
fi
NUM_WORKERS=$1
ROUTING_STRATEGY=$2
MODEL_NAME=${3:-"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"}
VALID_STRATEGIES=("prefix" "round_robin" "random")
ENDPOINT_NAME=${4:-"triton-init.process.chat/completions"}
VALID_STRATEGIES=("prefix")
SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/triton/venv/bin/activate && cd $WORKDIR"
if [[ ! " ${VALID_STRATEGIES[@]} " =~ " ${ROUTING_STRATEGY} " ]]; then
echo "Error: Invalid routing strategy. Must be one of: ${VALID_STRATEGIES[*]}"
exit 1
fi
########################################################
# HTTP Server
########################################################
HTTP_CMD="TRD_LOG=DEBUG http"
tmux new-session -d -s "$SESSION_NAME-http"
tmux send-keys -t "$SESSION_NAME-http" "$INIT_CMD && $HTTP_CMD" C-m
SESSION_NAME="v"
WORKDIR="/workspace/examples/python_rs/llm/vllm"
INIT_CMD="source /opt/triton/venv/bin/activate && cd $WORKDIR"
########################################################
# LLMCTL
########################################################
LLMCTL_CMD="sleep 5 && llmctl http remove chat-model $MODEL_NAME && \
llmctl http add chat-model $MODEL_NAME $ENDPOINT_NAME && \
llmctl http list chat-model"
tmux new-session -d -s "$SESSION_NAME-llmctl"
tmux send-keys -t "$SESSION_NAME-llmctl" "$INIT_CMD && $LLMCTL_CMD" C-m
########################################################
# Processor
########################################################
# For now the processor takes the same args as the worker, until the ModelDeploymentCard is communicated over etcd
PROCESSOR_CMD="RUST_LOG=info python3 -m kv_router.processor \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
--enable-prefix-caching \
--block-size 64 \
--max-model-len 16384 "
tmux new-session -d -s "$SESSION_NAME-processor"
tmux send-keys -t "$SESSION_NAME-processor" "$INIT_CMD && $PROCESSOR_CMD" C-m
########################################################
# Router
########################################################
ROUTER_CMD="RUST_LOG=info python3 -m kv_router.router \
--model $MODEL_NAME \
--routing-strategy $ROUTING_STRATEGY \
--min-workers $NUM_WORKERS "
tmux new-session -d -s "$SESSION_NAME-router"
tmux send-keys -t "$SESSION_NAME-router" "$INIT_CMD && $ROUTER_CMD" C-m
########################################################
# Workers
########################################################
WORKER_CMD="RUST_LOG=info python3 -m kv_router.worker \
--model $MODEL_NAME \
--tokenizer $MODEL_NAME \
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from typing import AsyncIterator
import uvloop
from common.chat_processor import ChatProcessor, ProcessMixIn
from common.parser import parse_vllm_args
from common.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from transformers import AutoTokenizer
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from triton_distributed_rs._core import Client
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
router_client: Client,
workers_client: Client,
):
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer(engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.router_client = router_client
self.workers_client = workers_client
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def generate_responses(
self, engine_generator
) -> AsyncIterator[RequestOutput]:
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
yield RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
@triton_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, raw_request):
request_id = str(uuid.uuid4())
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
worker_id_generator: AsyncIterator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
)
worker_id = (
await worker_id_generator.__anext__()
) # only one worker id is returned
worker_id = worker_id.data()
vllm_logger.info(f"Worker ID: {worker_id}")
if worker_id == "":
engine_generator = await self.workers_client.random(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
else:
engine_generator = await self.workers_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json(),
uuid.UUID(worker_id).int,
)
output = self.generate_responses(engine_generator)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Set up clients to the router and workers.
Serve the triton-init.process.chat/completions endpoint.
"""
workers_client = (
await runtime.namespace("triton-init")
.component("vllm")
.endpoint("generate")
.client()
)
router_client = (
await runtime.namespace("triton-init")
.component("router")
.endpoint("generate")
.client()
)
preprocess_component = runtime.namespace("triton-init").component("process")
await preprocess_component.create_service()
preprocess_endpoint = preprocess_component.endpoint("chat/completions")
processor = Processor(engine_args, router_client, workers_client)
assert isinstance(processor, ProcessMixIn)
await preprocess_endpoint.serve_endpoint(processor.generate)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
asyncio.run(worker(engine_args))
......@@ -15,12 +15,12 @@
import asyncio
import uuid
from argparse import Namespace
from enum import Enum
from typing import AsyncIterator
import uvloop
from common.protocol import Response, TokenizedRequest
from common.protocol import Tokens
from triton_distributed_rs import (
DistributedRuntime,
KvRouter,
......@@ -29,6 +29,8 @@ from triton_distributed_rs import (
)
from vllm.logger import logger as vllm_logger
WorkerId = str
class RoutingStrategy(Enum):
PREFIX = "prefix"
......@@ -43,19 +45,17 @@ class Router:
def __init__(
self,
router,
workers_client,
router: KvRouter,
routing_strategy: RoutingStrategy = RoutingStrategy.PREFIX,
):
vllm_logger.info(
f"Initializing KV Router with strategy: {routing_strategy.value}"
)
self.router = router
self.workers_client = workers_client
self.routing_strategy = routing_strategy
@triton_endpoint(TokenizedRequest, Response)
async def generate(self, request):
@triton_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
if self.routing_strategy == RoutingStrategy.PREFIX:
......@@ -70,35 +70,36 @@ class Router:
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
if self.routing_strategy == RoutingStrategy.ROUND_ROBIN:
engine_generator = await self.workers_client.round_robin(
request.model_dump_json()
)
elif self.routing_strategy == RoutingStrategy.RANDOM or worker_id == "":
engine_generator = await self.workers_client.random(
request.model_dump_json()
)
yield worker_id
else:
# extract back lease_id
engine_generator = await self.workers_client.direct(
request.model_dump_json(), uuid.UUID(worker_id).int
# TODO: Do we implement round_robin and random here?
# or just skip this router and directly enable in preprocess?
raise NotImplementedError(
f"Routing strategy {self.routing_strategy} not implemented"
)
async for resp in engine_generator:
resp = resp.data() if hasattr(resp, "data") else resp
yield resp
@triton_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
"""
Set up the worker clients.
Serve the triton-init.router.generate endpoint.
"""
workers_client = (
await runtime.namespace("triton-init")
.component("vllm")
.endpoint("generate_from_tokens")
.endpoint("generate")
.client()
)
vllm_logger.info("Waiting for workers to be ready")
await workers_client.wait_for_endpoints()
wait_task = workers_client.wait_for_endpoints()
await asyncio.sleep(1)
while not wait_task.done():
vllm_logger.info("Waiting for workers to be ready...")
await asyncio.sleep(5)
wait_task.result()
while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info(
......@@ -112,22 +113,16 @@ async def worker(runtime: DistributedRuntime, args: Namespace):
)
# TODO Router is a fixed namespace separate from the others
kv_listener = runtime.namespace("router").component(
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
)
kv_listener = runtime.namespace("router").component(args.model_name)
await kv_listener.create_service()
router_component = runtime.namespace("triton-init").component("router")
await router_component.create_service()
router = None
if args.routing_strategy == RoutingStrategy.PREFIX:
router = KvRouter(runtime, kv_listener)
router = KvRouter(runtime, kv_listener)
endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(
Router(router, workers_client, args.routing_strategy).generate
)
await endpoint.serve_endpoint(Router(router, args.routing_strategy).generate)
if __name__ == "__main__":
......@@ -149,6 +144,12 @@ if __name__ == "__main__":
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
args = parser.parse_args()
asyncio.run(worker(args))
......@@ -16,111 +16,75 @@
import asyncio
import os
import uuid
from typing import Optional
from typing import AsyncIterator
import uvloop
import vllm
from common.base_engine import BaseVllmEngine
from common.parser import parse_vllm_args
from common.protocol import Request, Response, TokenizedRequest
from triton_distributed_rs import (
DistributedRuntime,
KvRouter,
triton_endpoint,
triton_worker,
)
from common.protocol import MyRequestOutput, vLLMGenerateRequest
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
from vllm.logger import logger as vllm_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.sampling_params import RequestOutputKind
vllm_logger.info(f"VLLM_KV_CAPI_PATH: {os.environ['VLLM_KV_CAPI_PATH']}")
class VllmEngine:
class VllmEngine(BaseVllmEngine):
"""
Request handler for the generate endpoint
vLLM Inference Engine
"""
def __init__(self, engine_args: AsyncEngineArgs, router: KvRouter):
self.engine = vllm.AsyncLLMEngine.from_engine_args(engine_args)
self.router = router
self.tokenizer: Optional[AnyTokenizer] = None
def __init__(self, engine_args: AsyncEngineArgs):
self.engine_args = engine_args
super().__init__(engine_args)
# Pattern to initialize async object as python __init__ is not async
async def init(self):
self.tokenizer = await self.engine.get_tokenizer()
return self
@triton_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request) -> AsyncIterator:
if self.engine_client is None:
await self.initialize()
assert self.engine_client is not None, "engine_client was not initialized"
@triton_endpoint(TokenizedRequest, Response)
async def generate_from_tokens(self, request):
tokens_prompt = TokensPrompt(prompt_token_ids=request.tokens)
sampling_params = request.sampling_params
# The Rust HTTP frontend requires DELTA output streaming
sampling_params.output_kind = RequestOutputKind.DELTA
sampling_params = vllm.SamplingParams(**request.sampling_params)
request_id = str(uuid.uuid4())
async for response in self.engine.generate(
tokens_prompt, sampling_params, request_id
async for response in self.engine_client.generate(
request.engine_prompt, sampling_params, request.request_id
):
yield response.outputs[0].text
@triton_endpoint(Request, Response)
async def generate_from_prompt(self, request):
sampling_params = vllm.SamplingParams(**request.sampling_params)
request_id = str(uuid.uuid4())
async for response in self.engine.generate(
request.prompt, sampling_params, request_id
):
yield response.outputs[0].text
@triton_endpoint(Request, Response)
async def preprocess(self, request):
if self.tokenizer is None:
raise RuntimeError("Tokenizer not initialized. Must run init().")
tokens = self.tokenizer.encode(request.prompt)
engine_generator = await self.router.generate(
TokenizedRequest(tokens=tokens, **request.model_dump()).model_dump_json()
)
async for resp in engine_generator:
yield resp.data()
# MyRequestOutput takes care of serializing the response as
# vLLM's RequestOutput is not serializable by default
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
@triton_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
Serve the triton-init.vllm.generate endpoint.
"""
worker_component = runtime.namespace("triton-init").component("vllm")
await worker_component.create_service()
preprocess_component = runtime.namespace("triton-init").component("preprocess")
await preprocess_component.create_service()
router_client = (
await runtime.namespace("triton-init")
.component("router")
.endpoint("generate")
.client()
)
worker_from_tokens_endpoint = worker_component.endpoint("generate_from_tokens")
worker_from_prompt_endpoint = worker_component.endpoint("generate")
preprocess_endpoint = preprocess_component.endpoint("generate")
worker_endpoint = worker_component.endpoint("generate")
# TODO Hack until we unify lease_id and worker_id
VLLM_WORKER_ID = uuid.UUID(int=worker_from_tokens_endpoint.lease_id())
# KV Publisher and Aggregator requires a UUID (str)
# KV Router requires a lease_id (int)
# This allows us to please both, until they are unified
# If VLLM_WORKER_ID is not set, KV Routing will fail
VLLM_WORKER_ID = uuid.UUID(int=worker_endpoint.lease_id())
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")
vllm_engine = VllmEngine(engine_args, router_client)
vllm_engine = await vllm_engine.init()
vllm_engine = VllmEngine(engine_args)
await asyncio.gather(
worker_from_tokens_endpoint.serve_endpoint(vllm_engine.generate_from_tokens),
worker_from_prompt_endpoint.serve_endpoint(vllm_engine.generate_from_prompt),
preprocess_endpoint.serve_endpoint(vllm_engine.preprocess),
)
await worker_endpoint.serve_endpoint(vllm_engine.generate)
if __name__ == "__main__":
......
......@@ -19,6 +19,7 @@ import uuid
import uvloop
from common.base_engine import BaseVllmEngine
from common.chat_processor import ProcessMixIn
from common.parser import parse_vllm_args
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
from vllm.engine.arg_utils import AsyncEngineArgs
......@@ -29,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import logger as vllm_logger
class VllmEngine(BaseVllmEngine):
class VllmEngine(BaseVllmEngine, ProcessMixIn):
"""
Request handler for the generate endpoint
"""
......