Commit ab33729b authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

fix: Add missing util files to vllm example (#105)

parent b0655a34
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
class NATSQueue:
_instance: ClassVar[Optional["NATSQueue"]] = None
_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
def __init__(
self,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
self.nats_url = nats_server
self._nc: Optional[NATS] = None
self._js: Optional[JetStreamContext] = None
# TODO: check if this is needed
# Sanitize stream_name to remove path separators
self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
self._subject = f"{self._stream_name}.*"
self.dequeue_timeout = dequeue_timeout
self._subscriber: Optional[JetStreamContext.PullSubscription] = None
@classmethod
@asynccontextmanager
async def get_instance(
cls,
*,
stream_name: str = "default",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
"""Get or create a singleton instance of NATSq"""
# TODO: check if this _lock is needed with GIL
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
await cls._instance.connect()
try:
yield cls._instance
except Exception:
if cls._instance:
await cls._instance.close()
cls._instance = None
raise
# TODO: check to see if this can be replaced by something like get_instance().close()
@classmethod
async def shutdown(cls):
"""Explicitly close the singleton instance if it exists"""
async with cls._lock:
if cls._instance:
await cls._instance.close()
cls._instance = None
async def connect(self):
"""Establish connection and create stream if needed"""
try:
if self._nc is None:
self._nc = NATS()
await self._nc.connect(self.nats_url)
self._js = self._nc.jetstream()
# Check if stream exists, if not create it
try:
await self._js.stream_info(self._stream_name)
except NotFoundError:
await self._js.add_stream(
name=self._stream_name, subjects=[self._subject]
)
# Create persistent subscriber
self._subscriber = await self._js.pull_subscribe(
f"{self._stream_name}.queue", durable="worker-group"
)
except NatsError as e:
await self.close()
raise ConnectionError(f"Failed to connect to NATS: {e}")
async def ensure_connection(self):
"""Ensure we have an active connection"""
if self._nc is None or self._nc.is_closed:
await self.connect()
async def close(self):
"""Close the connection when done"""
if self._nc:
await self._nc.close()
self._nc = None
self._js = None
self._subscriber = None
# TODO: is enqueue/dequeue_object a better name for a general queue?
async def enqueue_task(self, task_data: bytes) -> None:
"""
Enqueue a task using msgspec-encoded data
"""
await self.ensure_connection()
try:
await self._js.publish(f"{self._stream_name}.queue", task_data) # type: ignore
except NatsError as e:
raise RuntimeError(f"Failed to enqueue task: {e}")
async def dequeue_task(self) -> Optional[bytes]:
"""Dequeue and return a task as raw bytes, to be decoded with msgspec"""
await self.ensure_connection()
try:
msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout) # type: ignore
if msgs:
msg = msgs[0]
await msg.ack()
return msg.data
return None
except asyncio.TimeoutError:
return None
except NatsError as e:
raise RuntimeError(f"Failed to dequeue task: {e}")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl"
@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
os.makedirs(METADATA_DIR, exist_ok=True)
path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
with open(path, "wb") as f:
encoded = msgspec.msgpack.encode(metadata)
print(f"Size of encoded metadata: {len(encoded)}")
f.write(encoded)
try:
yield path
finally:
if os.path.exists(path):
os.remove(path)
def find_remote_metadata(engine_id):
# find and load metadata from METADATA_DIR that do not match engine_id
remote_metadata = []
for file in os.listdir(METADATA_DIR):
if file.endswith(".nixl_meta"):
if file.split(".")[0] != engine_id:
with open(os.path.join(METADATA_DIR, file), "rb") as f:
remote_metadata.append(
msgspec.msgpack.decode(f.read(), type=NixlMetadata)
)
return remote_metadata
class NixlMetadataStore:
NIXL_METADATA_KEY = "nixl_metadata"
def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
self._namespace = namespace
# TODO Remove metadata from etcd on delete
self._stored: set[str] = set()
self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client()
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata):
serialized_metadata = msgspec.msgpack.encode(metadata)
key = "/".join([self._key_prefix, engine_id])
await self._client.kv_put(key, serialized_metadata, None)
self._stored.add(engine_id)
async def get(self, engine_id) -> NixlMetadata:
try:
if engine_id in self._cached:
return self._cached[engine_id]
key = "/".join([self._key_prefix, engine_id])
key_values = await self._client.kv_get_prefix(key)
deserialized_metadata = None
for item in key_values:
deserialized_metadata = msgspec.msgpack.decode(
item["value"], type=NixlMetadata
)
break
if deserialized_metadata is None:
raise Exception("metadata not found in etcd")
self._cached[engine_id] = deserialized_metadata
# TODO watch for changes and update cache
# self._client.add_watch_callback(
# key,
# self._watch_callback,
# )
except Exception as e:
raise Exception("Error retrieving metadata for engine {engine_id}") from e
return deserialized_metadata
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
return engine_args
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment