"lib/vscode:/vscode.git/clone" did not exist on "c7cdc8cdb2fb255d94656c9022d8835432794013"
Unverified commit e79c4932, authored by leo-cf-tian and committed by GitHub
Browse files

feat: Add vLLM start/stop profile endpoints (#8068)


Signed-off-by: Leo Tian <lctian@nvidia.com>
parent 81819fbc
...@@ -663,6 +663,34 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -663,6 +663,34 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
logger.error(f"Failed to wake up engine: {e}") logger.error(f"Failed to wake up engine: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def start_profile(self, body: dict) -> dict:
    """Start profiling on the engine.

    Args:
        body: Dict with profiling parameters. A missing body (``None``) is
            treated like an empty dict. Supported keys:
            - profile_prefix (str|None): Optional prefix for profile output files.

    Returns:
        ``{"status": "ok", "message": "Profiling started"}`` on success, or
        ``{"status": "error", "message": <exception text>}`` if the request
        body cannot be read or the engine call raises.
    """
    try:
        # Parse the body inside the try so a missing or malformed payload is
        # reported through the same error response as an engine failure
        # instead of escaping the handler as an unhandled exception.
        profile_prefix = (body or {}).get("profile_prefix")
        await self.engine_client.start_profile(profile_prefix=profile_prefix)
        return {"status": "ok", "message": "Profiling started"}
    except Exception as e:
        logger.error(f"Failed to start profiling: {e}")
        return {"status": "error", "message": str(e)}
async def stop_profile(self, body: dict) -> dict:
    """Ask the engine client to stop profiling and report the outcome.

    Args:
        body: Ignored; present only so the method matches the engine-route
            handler signature.

    Returns:
        ``{"status": "ok", "message": "Profiling stopped"}`` when the engine
        call completes, otherwise ``{"status": "error", "message": ...}``
        carrying the text of the raised exception.
    """
    try:
        await self.engine_client.stop_profile()
    except Exception as exc:
        # Failures are logged and converted into an error payload rather
        # than propagated to the caller.
        logger.error(f"Failed to stop profiling: {exc}")
        return {"status": "error", "message": str(exc)}
    else:
        return {"status": "ok", "message": "Profiling stopped"}
@abstractmethod @abstractmethod
def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]: def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
raise NotImplementedError raise NotImplementedError
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from dynamo.vllm.handlers import BaseWorkerHandler
# Module-level pytest markers: pytest applies every marker in this list to
# each test collected from this file.
# NOTE(review): the marker meanings (e.g. gpu_0 presumably = no GPU needed,
# pre_merge presumably = runs in the pre-merge CI gate) are defined by the
# project's pytest configuration — confirm there.
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
class _TestWorkerHandler(BaseWorkerHandler):
    """Smallest concrete ``BaseWorkerHandler`` usable in these tests.

    It only supplies the ``generate`` method so the class can be
    instantiated; the profiling methods under test are inherited.
    """

    async def generate(self, request, context):
        """Yield a single empty dict, ignoring both arguments."""
        empty_response = {}
        yield empty_response
def _make_handler() -> _TestWorkerHandler:
    """Build a handler whose engine client is a pair of ``AsyncMock`` stubs.

    The instance is created with ``__new__`` so ``__init__`` is never run;
    only the ``engine_client`` attribute is populated.
    """
    fake_engine = SimpleNamespace(
        start_profile=AsyncMock(),
        stop_profile=AsyncMock(),
    )
    instance = _TestWorkerHandler.__new__(_TestWorkerHandler)
    instance.engine_client = fake_engine
    return instance
@pytest.mark.asyncio
async def test_start_profile_calls_engine_with_prefix():
    """A supplied ``profile_prefix`` is forwarded to the engine client."""
    worker = _make_handler()

    response = await worker.start_profile({"profile_prefix": "prefix"})

    assert response["status"] == "ok"
    engine_start = worker.engine_client.start_profile
    engine_start.assert_awaited_once_with(profile_prefix="prefix")
@pytest.mark.asyncio
async def test_start_profile_without_prefix_passes_none():
    """An empty body results in ``profile_prefix=None`` at the engine client."""
    worker = _make_handler()

    response = await worker.start_profile({})

    assert response["status"] == "ok"
    engine_start = worker.engine_client.start_profile
    engine_start.assert_awaited_once_with(profile_prefix=None)
@pytest.mark.asyncio
async def test_stop_profile_calls_engine():
    """``stop_profile`` awaits the engine client once, with no arguments."""
    worker = _make_handler()

    response = await worker.stop_profile({})

    assert response["status"] == "ok"
    engine_stop = worker.engine_client.stop_profile
    engine_stop.assert_awaited_once_with()
...@@ -26,7 +26,12 @@ from dynamo.runtime import DistributedRuntime ...@@ -26,7 +26,12 @@ from dynamo.runtime import DistributedRuntime
from .args import Config from .args import Config
from .constants import DisaggregationMode from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker from .handlers import (
BaseWorkerHandler,
DecodeWorkerHandler,
PrefillWorkerHandler,
get_dp_range_for_worker,
)
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .multimodal_handlers import EncodeWorkerHandler from .multimodal_handlers import EncodeWorkerHandler
from .publisher import StatLoggerFactory from .publisher import StatLoggerFactory
...@@ -361,13 +366,8 @@ class WorkerFactory: ...@@ -361,13 +366,8 @@ class WorkerFactory:
component_name=config.component, component_name=config.component,
) )
# Register sleep/wake_up engine routes # Register engine routes
runtime.register_engine_route("sleep", handler.sleep) self.register_engine_routes(runtime, handler)
runtime.register_engine_route("wake_up", handler.wake_up)
runtime.register_engine_route("scale_elastic_ep", handler.scale_elastic_ep)
logger.info(
"Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep"
)
# Parse endpoint types from --endpoint-types flag # Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.endpoint_types) model_type = parse_endpoint_types(config.endpoint_types)
...@@ -576,13 +576,8 @@ class WorkerFactory: ...@@ -576,13 +576,8 @@ class WorkerFactory:
component_name=config.component, component_name=config.component,
) )
# Register sleep/wake_up engine routes # Register engine routes
runtime.register_engine_route("sleep", handler.sleep) self.register_engine_routes(runtime, handler)
runtime.register_engine_route("wake_up", handler.wake_up)
runtime.register_engine_route("scale_elastic_ep", handler.scale_elastic_ep)
logger.info(
"Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep"
)
await self._maybe_wait_for_failover_lock(handler, runtime, config) await self._maybe_wait_for_failover_lock(handler, runtime, config)
...@@ -666,3 +661,21 @@ class WorkerFactory: ...@@ -666,3 +661,21 @@ class WorkerFactory:
logger.info("Connected to encode workers") logger.info("Connected to encode workers")
return encode_worker_client return encode_worker_client
return None return None
def register_engine_routes(
    self, runtime: DistributedRuntime, handler: BaseWorkerHandler
) -> None:
    """Attach the handler's engine-control methods to the runtime as routes.

    Each route name is registered with the handler method of the same name.

    Args:
        runtime: The DistributedRuntime instance to register routes on.
        handler: The worker handler whose methods serve the routes.
    """
    route_names = (
        "start_profile",
        "stop_profile",
        "sleep",
        "wake_up",
        "scale_elastic_ep",
    )
    for route in route_names:
        runtime.register_engine_route(route, getattr(handler, route))
    logger.info(
        "Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep, /engine/start_profile, /engine/stop_profile"
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment