Commit b92834c8 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

chore: removing outdated examples (#202)

parent fd79234f
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest
class PrefillQueue(NATSQueue):
"""
A wrapper of NATSQueue for PrefillRequest.
The stream name is forced to be "prefill_queue".
"""
def __init__(
self,
stream_name="prefill_queue",
nats_server: str = "nats://localhost:4222",
dequeue_timeout: float = 1,
):
super().__init__(
stream_name=stream_name,
nats_server=nats_server,
dequeue_timeout=dequeue_timeout,
)
async def enqueue_prefill_request(
self, prefill_request: RemotePrefillRequest
) -> None:
encoded_request = msgspec.json.encode(prefill_request)
await self.enqueue_task(encoded_request)
async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
encoded_request = await self.dequeue_task()
if encoded_request is not None:
prefill_request = msgspec.json.decode(
encoded_request, type=RemotePrefillRequest
)
return prefill_request
else:
return None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs.data import TokensPrompt
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
class Request(BaseModel):
prompt: str
sampling_params: dict
class Tokens(BaseModel):
tokens: list[int]
class PrefillRequest(Request):
request_id: str
class Response(BaseModel):
text: str
class PrefillResponse(BaseModel):
prefilled: bool
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams type to add a dummy core schema so pydantic can validate it
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
engine_prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
prefix_hit_rate: Optional[float] = 0.0
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}
)
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_vllm_args() -> AsyncEngineArgs:
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
type=str,
choices=["random", "round-robin", "kv"],
default="random",
help="Router type to use for scheduling requests to workers",
)
parser.add_argument(
"--remote-prefill", action="store_true", help="Enable remote prefill"
)
parser.add_argument(
"--conditional-disagg",
action="store_true",
help="Use disaggregated router to decide whether to prefill locally or remotely",
)
parser.add_argument(
"--max-local-prefill-length",
type=int,
default=1000,
help="Maximum length of local prefill",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.router = args.router
engine_args.remote_prefill = args.remote_prefill
engine_args.conditional_disagg = args.conditional_disagg
engine_args.max_local_prefill_length = args.max_local_prefill_length
return engine_args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uvloop
from disagg_router import PyDisaggregatedRouter
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
class RequestHandler:
def __init__(
self,
model_name: str,
engine_client: EngineClient,
prefill_client,
do_remote_prefill: bool,
disaggregated_router: PyDisaggregatedRouter = None,
):
self.model_name = model_name
self.client = engine_client
self.prefill_client = prefill_client
self.openai_serving_chat = None
self.initialized = False
self.do_remote_prefill = (
do_remote_prefill # remote prefill is still controlled by the router
)
self.disaggregated_router = disaggregated_router
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = model_name
vllm_logger.info(
"Prefill queue: %s:%s",
self._prefill_queue_nats_server,
self._prefill_queue_stream_name,
)
print("RequestHandler initialized")
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint
async def callback(request: RemotePrefillRequest):
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
@dynamo_endpoint(vLLMGenerateRequest, MyRequestOutput)
async def generate(self, request):
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
disagg_router_decision = self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
)
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
vllm_logger.debug(
"Prefilling remotely for request %s with length %s",
request.request_id,
len(request.engine_prompt["prompt_token_ids"]),
)
else:
remote_prefill_params = None
vllm_logger.debug(
"Prefilling locally for request %s with length %s",
request.request_id,
len(request.engine_prompt["prompt_token_ids"]),
)
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.client.generate(
prompt=request.engine_prompt,
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("dynamo-init").component("vllm")
await component.create_service()
endpoint = component.endpoint("generate")
if engine_args.remote_prefill:
prefill_client = (
await runtime.namespace("dynamo-init")
.component("prefill")
.endpoint("generate")
.client()
)
else:
prefill_client = None
if engine_args.router == "kv":
# TODO: do we need these env vars?
VLLM_WORKER_ID = endpoint.lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
vllm_logger.info("Generate endpoint ID: %s", VLLM_WORKER_ID)
VLLM_KV_NAMESPACE = "dynamo-init"
os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)
VLLM_KV_COMPONENT = "vllm"
os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)
metrics_publisher = KvMetricsPublisher()
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
served_model_name = (
engine_args.served_model_name
if engine_args.served_model_name is not None
else "vllm"
)
if engine_args.router == "kv":
engine_client.set_metrics_publisher(metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
metrics_publisher.publish(
0, # request_active_slots
1024, # request_total_slots
0, # kv_active_blocks
1024, # kv_total_blocks
0, # num_requests_waiting
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
if engine_args.conditional_disagg:
disaggregated_router = PyDisaggregatedRouter(
runtime,
served_model_name,
max_local_prefill_length=engine_args.max_local_prefill_length,
)
else:
disaggregated_router = None
endpoints = [
endpoint.serve_endpoint(
RequestHandler(
model_name=served_model_name,
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=engine_args.remote_prefill,
disaggregated_router=disaggregated_router,
).generate
)
]
if engine_args.router == "kv":
endpoints.append(metrics_publisher.create_endpoint(component))
await asyncio.gather(*endpoints)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.remote_prefill:
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
engine_args.preemption_mode = "swap"
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import pytest
pytestmark = pytest.mark.gpu
def test_detect_gpu():
try:
result = subprocess.run(
["nvidia-smi"], capture_output=True, text=True, check=True
)
print("\nAvailable GPUs:")
print(result.stdout)
assert "NVIDIA" in result.stdout, "No NVIDIA GPUs found in nvidia-smi output"
except subprocess.CalledProcessError as e:
pytest.fail(f"nvidia-smi command failed with error: {e}")
except FileNotFoundError:
pytest.fail(
"nvidia-smi command not found. Ensure NVIDIA drivers are properly installed."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment