Unverified Commit f00d700e authored by Alec's avatar Alec Committed by GitHub
Browse files

refactor: remove old examples with old UX (#1899)

parent c7080419
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pydantic import BaseModel
from components.planner import start_planner # type: ignore[attr-defined]
from dynamo.planner.defaults import LoadPlannerDefaults
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
from dynamo.sdk.core.protocol.interface import ComponentType
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
class RequestType(BaseModel):
text: str
@service(
dynamo={
"namespace": "dynamo",
"component_type": ComponentType.PLANNER,
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
image=DYNAMO_IMAGE,
)
class Planner:
def __init__(self):
configure_dynamo_logging(service_name="Planner")
logger.info("Starting planner")
self.runtime = dynamo_context["runtime"]
config = ServiceConfig.get_instance()
# Get namespace directly from dynamo_context as it contains the active namespace
self.namespace = dynamo_context["namespace"]
config_instance = config.get("Planner", {})
self.args = argparse.Namespace(
namespace=self.namespace,
environment=config_instance.get(
"environment", LoadPlannerDefaults.environment
),
no_operation=config_instance.get(
"no-operation", LoadPlannerDefaults.no_operation
),
log_dir=config_instance.get("log-dir", LoadPlannerDefaults.log_dir),
adjustment_interval=config_instance.get(
"adjustment-interval", LoadPlannerDefaults.adjustment_interval
),
metric_pulling_interval=config_instance.get(
"metric-pulling-interval", LoadPlannerDefaults.metric_pulling_interval
),
max_gpu_budget=config_instance.get(
"max-gpu-budget", LoadPlannerDefaults.max_gpu_budget
),
min_endpoint=config_instance.get(
"min-endpoint", LoadPlannerDefaults.min_endpoint
),
decode_kv_scale_up_threshold=config_instance.get(
"decode-kv-scale-up-threshold",
LoadPlannerDefaults.decode_kv_scale_up_threshold,
),
decode_kv_scale_down_threshold=config_instance.get(
"decode-kv-scale-down-threshold",
LoadPlannerDefaults.decode_kv_scale_down_threshold,
),
prefill_queue_scale_up_threshold=config_instance.get(
"prefill-queue-scale-up-threshold",
LoadPlannerDefaults.prefill_queue_scale_up_threshold,
),
prefill_queue_scale_down_threshold=config_instance.get(
"prefill-queue-scale-down-threshold",
LoadPlannerDefaults.prefill_queue_scale_down_threshold,
),
decode_engine_num_gpu=config_instance.get(
"decode-engine-num-gpu", LoadPlannerDefaults.decode_engine_num_gpu
),
prefill_engine_num_gpu=config_instance.get(
"prefill-engine-num-gpu", LoadPlannerDefaults.prefill_engine_num_gpu
),
)
@async_on_start
async def async_init(self):
import asyncio
await asyncio.sleep(30)
logger.info("Calling start_planner")
await start_planner(self.runtime, self.args)
logger.info("Planner started")
@endpoint()
async def generate(self, request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
import sys
from pydantic import BaseModel
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(BaseModel):
text: str
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class PrefillWorker:
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self._loaded_metadata = set()
self.initialized = False
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.disable_async_output_proc is not True:
logger.info("Async output processing is not supported yet, setting to True")
self.engine_args.disable_async_output_proc = True
if self.engine_args.enforce_eager is not True:
logger.info("Prefill must be done eagerly, setting to True")
self.engine_args.enforce_eager = True
if self.engine_args.enable_prefix_caching is not False:
logger.info(
"Prefix caching is not supported yet in prefill worker, setting to False"
)
self.engine_args.enable_prefix_caching = False
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"]
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo", runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
self.task = asyncio.create_task(self.prefill_queue_handler())
def prefill_queue_handler_cb(fut):
try:
fut.result()
logger.info("prefill queue handler exited successfully")
except Exception as e:
logger.error(f"[ERROR] prefill queue handler failed: {e!r}")
sys.exit(1)
self.task.add_done_callback(prefill_queue_handler_cb)
self.shutdown_requested = False
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("PrefillWorker initialized")
async def graceful_shutdown(self, runtime):
logger.info("Received shutdown signal, shutting down DistributedRuntime")
# first shutdown the vllm engine
self.shutdown_requested = True
await asyncio.wait_for(self.task, timeout=None)
# then shutdown the mock endpoint
runtime.shutdown()
logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self):
"""Shutdown the background loop"""
logger.info("Shutting down vllm engine")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("PrefillWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
namespace, _ = PrefillWorker.dynamo_address() # type: ignore
prefill_queue_stream_name = f"{namespace}_prefill_queue"
logger.info(
f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
)
self.initialized = True
# TODO: integrate prefill_queue to a dynamo endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
logger.info("prefill queue handler started")
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
logger.info(
f"Dequeued prefill request: {prefill_request.request_id}"
)
async for _ in self.generate(prefill_request):
pass
if self.shutdown_requested:
logger.info(
"Shutdown requested, checking if engine has any pending prefill sending requests"
)
while True:
if not await self.engine_client.has_unfinished_requests():
break
logger.info(
"Engine has pending prefill sending requests, rechecking in 1 second..."
)
await asyncio.sleep(1)
self.shutdown_vllm_engine()
break
async def generate(self, request: RemotePrefillRequest):
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {request.engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@endpoint()
async def mock(self, req: RequestType):
yield f"mock_response: {req}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import uuid
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Tuple, Union
from components.kv_router import Router
from components.worker import VllmWorker
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.check_worker import check_required_workers
from utils.protocol import LocalBlockHashes, MyRequestOutput, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.llm import KvMetricsAggregator, compute_block_hash_for_seq_py
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
@service(
dynamo={
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
worker = depends(VllmWorker)
router = depends(Router)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.min_workers = 1
self.request_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
self.request_futures: Dict[str, asyncio.Future] = {}
self.num_worker_tasks = (
self.engine_args.router_num_threads
) # Number of worker tasks to process the queue
self.worker_tasks: List[asyncio.Task] = []
print(f"Processor init: {self.engine_args.router}")
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore
self.worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
self.use_router = self.engine_args.router in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
if self.use_router:
router_ns, router_name = Router.dynamo_address() # type: ignore
self.router_client = (
await runtime.namespace(router_ns)
.component(router_name)
.endpoint("generate")
.client()
)
await check_required_workers(self.worker_client, self.min_workers)
kv_listener = runtime.namespace("dynamo").component("VllmWorker")
await kv_listener.create_service()
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
f"/{comp_ns}/processor/",
{"router": self.engine_args.router},
)
# Start multiple worker tasks to process the queue
self._start_worker_tasks()
def _start_worker_tasks(self):
"""Start multiple worker tasks to process the queue concurrently"""
# Clear any existing worker tasks
for task in self.worker_tasks:
if not task.done():
task.cancel()
self.worker_tasks = []
# Create new worker tasks
for i in range(self.num_worker_tasks):
task = asyncio.create_task(self._process_queue(worker_id=i))
self.worker_tasks.append(task)
logger.info(f"Started {self.num_worker_tasks} queue worker tasks")
async def _process_queue(self, worker_id: int):
"""Background task to process the request queue"""
logger.info(f"Queue worker {worker_id} started")
while True:
try:
# Get the next request from the queue
request_data = await self.request_queue.get()
# Process the request
try:
await self._process_request(request_data)
except Exception as e:
logger.error(f"Worker {worker_id}: Error processing request: {e}")
finally:
# Mark the task as done
self.request_queue.task_done()
except asyncio.CancelledError:
logger.info(f"Queue worker {worker_id} was cancelled")
break
except Exception as e:
logger.error(
f"Worker {worker_id}: Unexpected error in queue processing: {e}"
)
# Sleep briefly to avoid tight error loops
await asyncio.sleep(0.1)
async def _get_kv_load(self):
metrics = await self.metrics_aggregator.get_metrics()
kv_load = {}
for end_point in metrics.endpoints:
worker_id = end_point.worker_id
kv_load[worker_id] = getattr(end_point, "gpu_cache_usage_perc", 0.0)
return kv_load
async def _get_pending_requests(self):
metrics = await self.metrics_aggregator.get_metrics()
pending_requests = {}
for end_point in metrics.endpoints:
worker_id = end_point.worker_id
pending_requests[worker_id] = getattr(endpoint, "num_requests_waiting", 0)
return pending_requests
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
request_type: RequestType,
):
request_id = str(uuid.uuid4())
logger.debug(f"Got raw request: {raw_request}")
# Create a future for this request
future: asyncio.Future[AsyncIterator[Any]] = asyncio.Future()
self.request_futures[request_id] = future
# Enqueue the request with minimal processing
await self.request_queue.put(
{
"request_id": request_id,
"raw_request": raw_request,
"request_type": request_type,
}
)
try:
# Wait for the future to complete and yield the results
generator = await future
async for response in generator:
yield response
finally:
# Clean up the future when done
if request_id in self.request_futures:
del self.request_futures[request_id]
async def _process_request(self, request_data: Dict[str, Any]):
"""Process a single request from the queue"""
request_id = request_data["request_id"]
raw_request = request_data["raw_request"]
request_type = request_data["request_type"]
try:
# Parse the raw request here instead of in _generate
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
# Create an async generator function to process this request
async def process_and_stream():
# TODO: queue request at processor when engines are full
router_mode = (await self.etcd_kv_cache.get("router")).decode()
self.use_router = router_mode in (
RouterType.KV,
RouterType.KV_LOAD,
RouterType.APPROX_KV,
)
prefix_hit_rate = 0.0 # Default value
if self.use_router:
token_ids = engine_prompt["prompt_token_ids"]
router_generator = await self.router_client.generate(
LocalBlockHashes(
hashes=compute_block_hash_for_seq_py(
token_ids, self.engine_args.block_size
),
tokens=token_ids,
num_tokens=len(token_ids),
).model_dump_json()
)
decision = await router_generator.__anext__()
worker_id, prefix_hit_rate = decision.data()
prefix_hit_rate = float(prefix_hit_rate)
# Create request object once with default prefix_hit_rate
request_obj = vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
if self.use_router:
if worker_id == "":
engine_generator = await self.worker_client.generate(
request_obj
)
else:
engine_generator = await self.worker_client.direct(
request_obj, int(worker_id)
)
elif router_mode == RouterType.RANDOM:
engine_generator = await self.worker_client.generate(request_obj)
elif router_mode == RouterType.ROUND_ROBIN:
engine_generator = await self.worker_client.round_robin(request_obj)
output_generator = self._generate_responses(
engine_generator, request_type
)
# Stream responses directly to the caller
async for response in await self._stream_response(
request, output_generator, request_id, conversation
):
yield response
# Set the future result to our async generator
if request_id in self.request_futures:
self.request_futures[request_id].set_result(process_and_stream())
except Exception as e:
logger.error(f"Error processing request {request_id}: {e}")
# Set exception on the future if it still exists
if (
request_id in self.request_futures
and not self.request_futures[request_id].done()
):
self.request_futures[request_id].set_exception(e)
async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
@endpoint(name="chat/completions")
async def chat_completions(self, raw_request: ChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
# @endpoint()
# async def completions(self, raw_request: CompletionRequest):
# async for response in self._generate(raw_request, RequestType.COMPLETION):
# yield response
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
from components.disagg_router import PyDisaggregatedRouter
from components.prefill_worker import PrefillWorker
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind
from dynamo.llm import ForwardPassMetrics, KvStats, WorkerMetricsPublisher, WorkerStats
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmWorker:
prefill_worker = depends(PrefillWorker)
def __init__(self):
self.client = None
self.disaggregated_router: PyDisaggregatedRouter = None # type: ignore
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.do_remote_prefill = self.engine_args.remote_prefill
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self.namespace, _ = VllmWorker.dynamo_address() # type: ignore
self._prefill_queue_stream_name = f"{self.namespace}_prefill_queue"
logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
if self.engine_args.remote_prefill:
if self.engine_args.enable_chunked_prefill is not False:
logger.info("Chunked prefill is not supported yet, setting to False")
self.engine_args.enable_chunked_prefill = False
if self.engine_args.preemption_mode != "swap":
logger.info("Preemption mode is not supported yet, setting to swap")
self.engine_args.preemption_mode = "swap"
if self.engine_args.pipeline_parallel_size != 1:
logger.info("Pipeline parallel size is not supported yet, setting to 1")
self.engine_args.pipeline_parallel_size = 1
if self.engine_args.router in (RouterType.KV, RouterType.APPROX_KV):
if not self.engine_args.enable_prefix_caching:
logger.info(
"When using KV router, prefix caching must be enabled, setting to True"
)
self.engine_args.enable_prefix_caching = True
VLLM_WORKER_ID = dynamo_context["endpoints"][0].lease_id()
os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
os.environ["VLLM_KV_NAMESPACE"] = "dynamo"
os.environ["VLLM_KV_COMPONENT"] = class_name
self.metrics_publisher = WorkerMetricsPublisher()
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
@async_on_start
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
self.engine_client.set_metrics_publisher(self.metrics_publisher)
# Initially send dummy metrics to kick start,
# vLLM will not update stat until forward pass is triggered
worker_stats = WorkerStats(
0, # request_active_slots
1024, # request_total_slots
0, # num_requests_waiting
None, # data_parallel_rank
)
kv_stats = KvStats(
0, # kv_active_blocks
1024, # kv_total_blocks
0.0, # gpu_cache_usage_perc
0.0, # gpu_prefix_cache_hit_rate
)
metrics = ForwardPassMetrics(
worker_stats=worker_stats,
kv_stats=kv_stats,
spec_decode_stats=None,
)
self.metrics_publisher.publish(metrics)
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
task.add_done_callback(
lambda _: logger.info("metrics publisher endpoint created")
)
runtime = dynamo_context["runtime"]
if self.engine_args.remote_prefill:
metadata = self.engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo", runtime)
await metadata_store.put(metadata.engine_id, metadata)
if self.engine_args.conditional_disagg:
self.disaggregated_router = PyDisaggregatedRouter(
runtime,
self.namespace,
max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
)
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = None
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("VllmWorker has been initialized")
async def graceful_shutdown(self, runtime):
logger.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("VllmWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
logger.info("Creating metrics publisher endpoint with primary lease")
await self.metrics_publisher.create_endpoint(component)
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint
async def callback(request: RemotePrefillRequest):
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
await prefill_queue.enqueue_prefill_request(request)
return callback
# TODO: use the same child lease for metrics publisher endpoint and generate endpoint
@endpoint()
async def generate(self, request: vLLMGenerateRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely
if self.disaggregated_router is not None:
async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server,
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]),
request.prefix_hit_rate,
prefill_queue_size,
)
else:
# always prefill remotely if no disaggregated router is provided
disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else:
remote_prefill_params = None
logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
# rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA
async for response in self.engine_client.generate(
prompt=request.engine_prompt,
sampling_params=request.sampling_params,
request_id=request.request_id,
remote_prefill_params=remote_prefill_params,
):
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
router: round-robin
router-num-threads: 4
common-configs: [model, block-size, max-model-len]
VllmWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len]
Planner:
environment: local
no-operation: true
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
router: kv
block-size: 64
max-model-len: 16384
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
common-configs: [model, block-size, max-model-len, router]
Router:
min-workers: 1
softmax-sample: true
common-configs: [model, block-size, router]
VllmWorker:
enforce-eager: true
max-num-batched-tokens: 16384
enable-prefix-caching: true
tensor-parallel-size: 1
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, router, kv-transfer-config]
Planner:
environment: local
no-operation: true
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000
Processor:
router: round-robin
common-configs: [model, block-size]
VllmWorker:
remote-prefill: true
conditional-disagg: true
max-local-prefill-length: 10
max-prefill-queue-size: 2
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config]
PrefillWorker:
max-num-batched-tokens: 16384
ServiceArgs:
workers: 1
resources:
gpu: '1'
common-configs: [model, block-size, max-model-len, kv-transfer-config]
Planner:
environment: local
no-operation: true
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment