Unverified Commit 3a925951 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

refactor: [vLLM] move prefill / decode worker initialization into worker factory (#7367)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent ba2d9684
......@@ -37,6 +37,7 @@ from dynamo.llm import (
register_model,
unregister_model,
)
from dynamo.runtime import Client
from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
......@@ -1341,7 +1342,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__(
runtime,
engine,
......@@ -1556,7 +1563,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__(
runtime,
engine,
......
......@@ -17,9 +17,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.prometheus import (
LLMBackendMetrics,
......@@ -34,15 +32,14 @@ from dynamo.llm import (
fetch_model,
register_model,
)
from dynamo.runtime import DistributedRuntime, Endpoint
from dynamo.runtime import Endpoint
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.worker_factory import WorkerFactory
from . import envs
from .args import Config, _uses_dynamo_connector, parse_args
from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .handlers import get_dp_range_for_worker
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
from .snapshot import prepare_snapshot_engine
......@@ -142,6 +139,8 @@ async def worker() -> None:
use_kv_events=config.use_kv_events,
)
# [gluo FIXME] should be after init() below? 'shutdown_endpoints' are populated
# there
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags
......@@ -151,6 +150,8 @@ async def worker() -> None:
setup_vllm_engine_fn=setup_vllm_engine,
setup_kv_event_publisher_fn=setup_kv_event_publisher,
register_vllm_model_fn=register_vllm_model,
setup_fpm_relay_fn=setup_fpm_relay,
setup_metrics_collection_fn=setup_metrics_collection,
)
await factory.create(
runtime,
......@@ -159,23 +160,9 @@ async def worker() -> None:
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
logger.debug("multimodal worker completed")
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
await init_prefill(
runtime,
config,
shutdown_event,
snapshot_engine=snapshot_engine,
)
logger.debug("init_prefill completed")
logger.debug("worker init completed")
else:
await init(
runtime,
config,
shutdown_event,
snapshot_engine=snapshot_engine,
)
logger.debug("init completed")
raise ValueError("Unsupported worker configuration")
logger.debug("Worker function completed, exiting...")
......@@ -625,378 +612,6 @@ async def register_vllm_model(
)
async def init_prefill(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
snapshot_engine: Optional[
tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
"""
Instantiate and serve
"""
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
fpm_worker_id = str(generate_endpoint.connection_id())
if snapshot_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = snapshot_engine
# TODO: The scheduler in the child process still has worker_id=""
# because the engine was forked before the runtime existed.
# Propagating the new ID to the child requires shared memory or
# a restart of the EngineCore process.
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
else:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = setup_vllm_engine(config, fpm_worker_id=fpm_worker_id)
handler = PrefillWorkerHandler(
runtime,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
)
handler.add_temp_dir(prometheus_temp_dir)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = _consolidator_eps[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
# Set up forward pass metrics relay (child ZMQ -> event plane).
# In checkpoint mode the engine was created before the runtime, so
# ForwardPassMetrics.worker_id will be empty (relay still works).
fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
if fpm_relays:
handler.fpm_relays = fpm_relays
setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
await register_vllm_model(
model_input,
ModelType.Prefill,
generate_endpoint,
config,
engine_client,
vllm_config,
)
health_check_payload = VllmPrefillHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for prefill worker")
await asyncio.gather(
# for prefill, we want to shutdown the engine after all prefill requests are finished because
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, # type: ignore
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
),
)
logger.debug("serve_endpoint completed for prefill worker")
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
logger.debug("Cleaning up prefill worker")
handler.cleanup()
async def init(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
snapshot_engine: Optional[
tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
"""
Instantiate and serve
"""
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
shutdown_endpoints[:] = [
generate_endpoint,
clear_endpoint,
]
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.load_lora"
)
unload_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.unload_lora"
)
list_loras_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.list_loras"
)
shutdown_endpoints.extend(
[
load_lora_endpoint,
unload_lora_endpoint,
list_loras_endpoint,
]
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
fpm_worker_id = str(generate_endpoint.connection_id())
if snapshot_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
component_gauges,
) = snapshot_engine
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
# Factory is created after unpack so component_gauges is available
factory = StatLoggerFactory(
endpoint=generate_endpoint,
component_gauges=component_gauges,
)
else:
# Factory is created without component_gauges; setup_vllm_engine() will
# create the gauges after setup_multiprocess_prometheus() and set them
# on the factory before vLLM calls create_stat_logger().
factory = StatLoggerFactory(
endpoint=generate_endpoint,
)
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
component_gauges,
) = setup_vllm_engine(config, factory, fpm_worker_id=fpm_worker_id)
# TODO Hack to get data, move this to registering in TBD
factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
factory.init_publish()
handler = DecodeWorkerHandler(
runtime,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
)
handler.add_temp_dir(prometheus_temp_dir)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = _consolidator_eps[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publisher for prefix caching if enabled
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
# Set up forward pass metrics relay (child ZMQ -> event plane).
# In checkpoint mode the engine was created before the runtime, so
# ForwardPassMetrics.worker_id will be empty (relay still works).
fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
if fpm_relays:
handler.fpm_relays = fpm_relays
setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.endpoint_types)
logger.info(f"Registering model with endpoint types: {config.endpoint_types}")
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
# Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.endpoint_types:
logger.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
)
await register_vllm_model(
model_input,
model_type,
generate_endpoint,
config,
engine_client,
vllm_config,
)
health_check_payload = VllmHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for decode worker")
model_metrics_labels = [
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
]
serve_tasks = [
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks,
metrics_labels=model_metrics_labels,
),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=model_metrics_labels,
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=model_metrics_labels,
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=model_metrics_labels,
),
]
)
await asyncio.gather(*serve_tasks)
logger.debug("serve_endpoint completed for decode worker")
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
logger.debug("Cleaning up decode worker")
# Cleanup background tasks
handler.cleanup()
def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]:
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
......
......@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, Mock
import pytest
from dynamo.vllm.constants import DisaggregationMode
from dynamo.vllm.worker_factory import EngineSetupResult, WorkerFactory
pytestmark = [
......@@ -25,7 +26,8 @@ def _make_config(**overrides) -> Mock:
"multimodal_worker": False,
"multimodal_decode_worker": False,
"omni": False,
"is_prefill_worker": False,
"route_to_encoder": False,
"disaggregation_mode": DisaggregationMode.AGGREGATED,
}
defaults.update(overrides)
return Mock(**defaults)
......@@ -34,31 +36,55 @@ def _make_config(**overrides) -> Mock:
class TestHandles:
"""Test WorkerFactory.handles() config detection."""
def test_multimodal_encode_worker(self) -> None:
config = _make_config(multimodal_encode_worker=True)
# Legacy worker config
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_multimodal_encode_worker(self, route_to_encode: bool) -> None:
# 'route_to_encoder' can be passed, the worker creation may ignore it.
config = _make_config(
multimodal_encode_worker=True, route_to_encoder=route_to_encode
)
assert WorkerFactory.handles(config)
def test_multimodal_worker(self) -> None:
config = _make_config(multimodal_worker=True)
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_multimodal_worker(self, route_to_encode: bool) -> None:
config = _make_config(multimodal_worker=True, route_to_encoder=route_to_encode)
assert WorkerFactory.handles(config)
def test_multimodal_decode_worker(self) -> None:
config = _make_config(multimodal_decode_worker=True)
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_multimodal_decode_worker(self, route_to_encode: bool) -> None:
config = _make_config(
multimodal_decode_worker=True, route_to_encoder=route_to_encode
)
assert WorkerFactory.handles(config)
def test_no_multimodal_flags(self) -> None:
config = _make_config()
assert not WorkerFactory.handles(config)
# Tests for no standalone encode worker setting
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_no_multimodal_flags(self, route_to_encode: bool) -> None:
config = _make_config(route_to_encoder=route_to_encode)
assert WorkerFactory.handles(config)
def test_omni_not_handled(self) -> None:
config = _make_config(omni=True)
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_prefill(self, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode,
)
# [gluo NOTE] due to current limitation, see 'WorkerFactory._validate_config()'.
if route_to_encode:
assert not WorkerFactory.handles(config)
else:
assert WorkerFactory.handles(config)
def test_prefill_only_not_handled(self) -> None:
config = _make_config(is_prefill_worker=True)
assert not WorkerFactory.handles(config)
@pytest.mark.parametrize("route_to_encode", [True, False])
def test_decode(self, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.DECODE,
route_to_encoder=route_to_encode,
)
assert WorkerFactory.handles(config)
@pytest.mark.asyncio
class TestCreate:
"""Test WorkerFactory.create() routing."""
......@@ -68,41 +94,90 @@ class TestCreate:
setup_vllm_engine_fn=Mock(),
setup_kv_event_publisher_fn=Mock(),
register_vllm_model_fn=AsyncMock(),
setup_fpm_relay_fn=Mock(),
setup_metrics_collection_fn=Mock(),
)
factory._create_multimodal_encode_worker = AsyncMock() # type: ignore[assignment]
factory._create_multimodal_worker = AsyncMock() # type: ignore[assignment]
factory._create_prefill_worker = AsyncMock() # type: ignore[assignment]
factory._create_decode_worker = AsyncMock() # type: ignore[assignment]
return factory
@pytest.mark.asyncio
async def test_routes_to_multimodal_encode(self, factory: WorkerFactory) -> None:
config = _make_config(multimodal_encode_worker=True)
# Tests for non-legacy worker config, 'route_to_encode' is worker internal config
# so either case should hit creation function.
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_aggregated(
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(route_to_encoder=route_to_encode)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_decode_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_prefill(self, factory: WorkerFactory, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode,
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_prefill_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_decode(self, factory: WorkerFactory, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.DECODE,
route_to_encoder=route_to_encode,
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_decode_worker.assert_called_once() # type: ignore[union-attr]
# Tests with legacy worker config.
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_to_multimodal_encode(
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(
multimodal_encode_worker=True, route_to_encoder=route_to_encode
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_encode_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio
async def test_routes_to_multimodal_worker(self, factory: WorkerFactory) -> None:
config = _make_config(multimodal_worker=True)
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_to_multimodal_worker(
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(multimodal_worker=True, route_to_encoder=route_to_encode)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_multimodal_decode_worker(
self, factory: WorkerFactory
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(multimodal_decode_worker=True)
config = _make_config(
multimodal_decode_worker=True, route_to_encoder=route_to_encode
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio
async def test_passes_snapshot_engine(self, factory: WorkerFactory) -> None:
config = _make_config(multimodal_worker=True)
runtime = Mock()
......@@ -131,9 +206,3 @@ class TestCreate:
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
@pytest.mark.asyncio
async def test_raises_when_no_multimodal_flag(self, factory: WorkerFactory) -> None:
config = _make_config()
with pytest.raises(ValueError, match="no multimodal worker type set"):
await factory.create(Mock(), config, asyncio.Event(), [])
......@@ -8,26 +8,36 @@ import logging
from collections.abc import Awaitable, Callable
from typing import Any, Optional
from vllm.config import VllmConfig
from vllm.v1.engine.async_llm import AsyncLLM
from dynamo import prometheus_names
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput
from dynamo.common.utils.prometheus import LLMBackendMetrics
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime
from .args import Config
from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
)
from .publisher import StatLoggerFactory
logger = logging.getLogger(__name__)
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges)
EngineSetupResult = tuple[Any, Any, Any, Any, Any]
EngineSetupResult = tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]]
RegisterVllmModelFn = Callable[..., Awaitable[None]]
SetupFpmRelayFn = Callable[..., Optional[list]]
SetupMetricsCollectionFn = Callable[..., None]
class WorkerFactory:
......@@ -38,18 +48,87 @@ class WorkerFactory:
setup_vllm_engine_fn: SetupVllmEngineFn,
setup_kv_event_publisher_fn: SetupKvEventPublisherFn,
register_vllm_model_fn: RegisterVllmModelFn,
setup_fpm_relay_fn: SetupFpmRelayFn,
setup_metrics_collection_fn: SetupMetricsCollectionFn,
):
self.setup_vllm_engine = setup_vllm_engine_fn
self.setup_kv_event_publisher = setup_kv_event_publisher_fn
self.register_vllm_model = register_vllm_model_fn
self.setup_fpm_relay = setup_fpm_relay_fn
self.setup_metrics_collection = setup_metrics_collection_fn
@staticmethod
def handles(config: Config) -> bool:
"""Return True if this factory handles the given config."""
return bool(
config.multimodal_encode_worker
or config.multimodal_worker
or config.multimodal_decode_worker
try:
WorkerFactory._validate_config(config)
return True
except (ValueError, NotImplementedError) as e:
logger.debug(
f"WorkerFactory cannot handle config: {e}, provided config: {WorkerFactory._config_str(config)}"
)
return False
@staticmethod
def _config_str(config: Config) -> str:
"""Helper function to format config for logging."""
return (
"{ "
f"multimodal_worker: {config.multimodal_worker}, "
f"multimodal_decode_worker: {config.multimodal_decode_worker}, "
f"multimodal_encode_worker: {config.multimodal_encode_worker}, "
f"disaggregation_mode: {config.disaggregation_mode}, "
f"route_to_encoder: {config.route_to_encoder}"
" }"
)
@staticmethod
def _validate_config(config: Config) -> None:
# [gluo FIXME] We are validating config combination for
# the transition away from "legacy" E/PD creation, which uses specialized
# P/D classes.
# In the future, we should rely on Dynamo runtime for P/D orchestration,
# thus the P/D worker in 'handlers.py' should soon be extended to support
# remote encode workflow, i.e. aware of encode worker client and perform remote
# encode when needed.
# Until then, we have validation on disaggregation mode and multimodal settings
# to guide user to use legacy mode for unsupported combination (see FIXME below).
legacy_multimodal_llm_worker = (
config.multimodal_worker or config.multimodal_decode_worker
)
if legacy_multimodal_llm_worker:
# [gluo] Sanity check, may be removed once legacy mode is removed.
# In the legacy mode, the specialized worker have P -> (optional D),
# so multimodal worker can be AGGREGATED or PREFILL, while
# multimodal decode worker can only be DECODE.
if (
config.multimodal_decode_worker
and config.disaggregation_mode == DisaggregationMode.PREFILL
):
raise ValueError(
"Multimodal decode worker with PREFILL disaggregation mode is not supported."
)
if (
config.multimodal_worker
and config.disaggregation_mode == DisaggregationMode.DECODE
):
raise ValueError(
"Multimodal worker with DECODE disaggregation mode is not supported."
)
# [gluo FIXME]
# 'route_to_encoder' hints standalone encode worker is used
# 'legacy_multimodal_llm_worker == False' hints Dynamo runtime will orchestrate
# P/D disagg and base P/D worker class should be used.
# In such a case, we can't use factory for P/D disaggregation modes because
# the current Dynamo runtime orchestrator is not aware of the extra mm data
# passing between P and D, P/D classes can't consume it properly untill
# the protocol is updated.
elif (
config.route_to_encoder
and config.disaggregation_mode == DisaggregationMode.PREFILL
):
raise NotImplementedError(
"Dynamo orchestrated disaggregated prefill worker, combined with remote encode worker is not supported."
)
async def create(
......@@ -62,11 +141,23 @@ class WorkerFactory:
) -> None:
"""Create the appropriate multimodal worker based on config flags."""
# Standalone encode worker
if config.multimodal_encode_worker:
await self._create_multimodal_encode_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
elif config.multimodal_worker or config.multimodal_decode_worker:
return
# [gluo WIP] This conditional should only be within worker creation,
# put here as some LLM worker setting is not compatible with
# standalone encode worker, so check supported combinations early.
# LLM connects to standalone encode worker
legacy_multimodal_llm_worker = (
config.multimodal_worker or config.multimodal_decode_worker
)
# Create P/D worker, internally may use remote encode worker for multimodal work
if legacy_multimodal_llm_worker:
await self._create_multimodal_worker(
runtime,
config,
......@@ -74,10 +165,28 @@ class WorkerFactory:
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
return
# [gluo FIXME] currently refactoring DecodeWorkerHandler from main.py for
# the use case of only disaggregating encode worker, so adding only decode
# worker creation for now, which is used in DisaggregationMode.AGGREGATED.
if config.disaggregation_mode == DisaggregationMode.PREFILL:
await self._create_prefill_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
else:
raise ValueError(
"WorkerFactory.create() called but no multimodal worker type set in config"
await self._create_decode_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
snapshot_engine=snapshot_engine,
)
return
async def _create_multimodal_worker(
self,
......@@ -139,14 +248,9 @@ class WorkerFactory:
) = self.setup_vllm_engine(config)
# Set up encode worker client when routing to encoder is enabled
encode_worker_client = None
if config.route_to_encoder:
encode_worker_client = await runtime.endpoint(
f"{config.namespace}.encoder.generate"
).client()
logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
logger.info("Connected to encoder workers")
encode_worker_client = await self._maybe_get_encode_worker_client(
runtime, config
)
# Set up decode worker client for disaggregated mode
decode_worker_client = None
......@@ -269,3 +373,402 @@ class WorkerFactory:
raise
finally:
handler.cleanup()
async def _create_decode_worker(
self,
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
snapshot_engine: Optional[EngineSetupResult] = None,
) -> None:
"""
Instantiate and serve
"""
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
shutdown_endpoints[:] = [
generate_endpoint,
clear_endpoint,
]
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.load_lora"
)
unload_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.unload_lora"
)
list_loras_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.list_loras"
)
shutdown_endpoints.extend(
[
load_lora_endpoint,
unload_lora_endpoint,
list_loras_endpoint,
]
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
fpm_worker_id = str(generate_endpoint.connection_id())
if snapshot_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
component_gauges,
) = snapshot_engine
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
# Factory is created after unpack so component_gauges is available
factory = StatLoggerFactory(
endpoint=generate_endpoint,
component_gauges=component_gauges,
)
else:
# Factory is created without component_gauges; setup_vllm_engine() will
# create the gauges after setup_multiprocess_prometheus() and set them
# on the factory before vLLM calls create_stat_logger().
factory = StatLoggerFactory(
endpoint=generate_endpoint,
)
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
component_gauges,
) = self.setup_vllm_engine(config, factory, fpm_worker_id=fpm_worker_id)
# TODO Hack to get data, move this to registering in TBD
factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
factory.init_publish()
# Currently routing to worker is still controlled by the worker
# as the worker has logic to determine whether remote encode should be
# performed
encode_worker_client = await self._maybe_get_encode_worker_client(
runtime, config
)
handler = DecodeWorkerHandler(
runtime,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
encode_worker_client=encode_worker_client,
)
handler.add_temp_dir(prometheus_temp_dir)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = _consolidator_eps[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publisher for prefix caching if enabled
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = self.setup_kv_event_publisher(
config,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
# Set up forward pass metrics relay (child ZMQ -> event plane).
# In checkpoint mode the engine was created before the runtime, so
# ForwardPassMetrics.worker_id will be empty (relay still works).
fpm_relays = self.setup_fpm_relay(generate_endpoint, vllm_config)
if fpm_relays:
handler.fpm_relays = fpm_relays
self.setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.endpoint_types)
logger.info(f"Registering model with endpoint types: {config.endpoint_types}")
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
# Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.endpoint_types:
logger.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
)
await self.register_vllm_model(
model_input,
model_type,
generate_endpoint,
config,
engine_client,
vllm_config,
)
health_check_payload = VllmHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for decode worker")
model_metrics_labels = [
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
]
serve_tasks = [
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks,
metrics_labels=model_metrics_labels,
),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=model_metrics_labels,
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=model_metrics_labels,
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=model_metrics_labels,
),
]
)
await asyncio.gather(*serve_tasks)
logger.debug("serve_endpoint completed for decode worker")
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
logger.debug("Cleaning up decode worker")
# Cleanup background tasks
handler.cleanup()
async def _create_prefill_worker(
self,
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
snapshot_engine: Optional[EngineSetupResult] = None,
) -> None:
"""
Instantiate and serve
"""
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
fpm_worker_id = str(generate_endpoint.connection_id())
if snapshot_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = snapshot_engine
# TODO: The scheduler in the child process still has worker_id=""
# because the engine was forked before the runtime existed.
# Propagating the new ID to the child requires shared memory or
# a restart of the EngineCore process.
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
else:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = self.setup_vllm_engine(config, fpm_worker_id=fpm_worker_id)
encode_worker_client = await self._maybe_get_encode_worker_client(
runtime, config
)
handler = PrefillWorkerHandler(
runtime,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
encode_worker_client=encode_worker_client,
)
handler.add_temp_dir(prometheus_temp_dir)
# Check if kv event consolidator is enabled (port was allocated in setup_vllm_engine)
consolidator_enabled = False
consolidator_port = None
_consolidator_eps = vllm_config.additional_config.get("consolidator_endpoints")
if _consolidator_eps:
# Extract connect endpoint (third element) for clients to subscribe
# consolidator_endpoints = (vllm_endpoint, bind_endpoint, connect_endpoint)
consolidator_output_endpoint = _consolidator_eps[2]
consolidator_port = int(consolidator_output_endpoint.split(":")[-1])
consolidator_enabled = True
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = self.setup_kv_event_publisher(
config,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
consolidator_port=consolidator_port,
)
if kv_publishers:
handler.kv_publishers = kv_publishers
# Set up forward pass metrics relay (child ZMQ -> event plane).
# In checkpoint mode the engine was created before the runtime, so
# ForwardPassMetrics.worker_id will be empty (relay still works).
fpm_relays = self.setup_fpm_relay(generate_endpoint, vllm_config)
if fpm_relays:
handler.fpm_relays = fpm_relays
self.setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake_up engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill
model_input = (
ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
)
await self.register_vllm_model(
model_input,
ModelType.Prefill,
generate_endpoint,
config,
engine_client,
vllm_config,
)
health_check_payload = VllmPrefillHealthCheckPayload(
engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict()
try:
logger.debug("Starting serve_endpoint for prefill worker")
await asyncio.gather(
# for prefill, we want to shutdown the engine after all prefill requests are finished because
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, # type: ignore
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
),
)
logger.debug("serve_endpoint completed for prefill worker")
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
finally:
logger.debug("Cleaning up prefill worker")
handler.cleanup()
async def _maybe_get_encode_worker_client(
self, runtime: DistributedRuntime, config: Config
) -> Optional[Any]:
"""Helper function to get encode worker client if routing to encoder is enabled."""
if config.route_to_encoder:
encode_worker_client = await runtime.endpoint(
f"{config.namespace}.encoder.generate"
).client()
logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
logger.info("Connected to encoder workers")
return encode_worker_client
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment