Commit 3a925951 (unverified), authored by GuanLuo, committed by GitHub
Browse files

refactor: [vLLM] move prefill / decode worker initialization into worker factory (#7367)


Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent ba2d9684
...@@ -37,6 +37,7 @@ from dynamo.llm import ( ...@@ -37,6 +37,7 @@ from dynamo.llm import (
register_model, register_model,
unregister_model, unregister_model,
) )
from dynamo.runtime import Client
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
...@@ -1341,7 +1342,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1341,7 +1342,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
): ):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__( super().__init__(
runtime, runtime,
engine, engine,
...@@ -1556,7 +1563,13 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1556,7 +1563,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
encode_worker_client: Client | None = None,
): ):
if encode_worker_client is not None:
raise NotImplementedError(
"'encode_worker_client' is provided which indicates remote "
"multimodal encode is configured, this is not currently supported."
)
super().__init__( super().__init__(
runtime, runtime,
engine, engine,
......
...@@ -17,9 +17,7 @@ from vllm.usage.usage_lib import UsageContext ...@@ -17,9 +17,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import install_signal_handlers from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.prometheus import ( from dynamo.common.utils.prometheus import (
LLMBackendMetrics, LLMBackendMetrics,
...@@ -34,15 +32,14 @@ from dynamo.llm import ( ...@@ -34,15 +32,14 @@ from dynamo.llm import (
fetch_model, fetch_model,
register_model, register_model,
) )
from dynamo.runtime import DistributedRuntime, Endpoint from dynamo.runtime import Endpoint
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.worker_factory import WorkerFactory from dynamo.vllm.worker_factory import WorkerFactory
from . import envs from . import envs
from .args import Config, _uses_dynamo_connector, parse_args from .args import Config, _uses_dynamo_connector, parse_args
from .constants import DisaggregationMode from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker from .handlers import get_dp_range_for_worker
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
from .snapshot import prepare_snapshot_engine from .snapshot import prepare_snapshot_engine
...@@ -142,6 +139,8 @@ async def worker() -> None: ...@@ -142,6 +139,8 @@ async def worker() -> None:
use_kv_events=config.use_kv_events, use_kv_events=config.use_kv_events,
) )
# [gluo FIXME] should this be installed after the worker initialization below?
# 'shutdown_endpoints' is only populated there.
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event) install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags # Route to appropriate initialization based on config flags
...@@ -151,6 +150,8 @@ async def worker() -> None: ...@@ -151,6 +150,8 @@ async def worker() -> None:
setup_vllm_engine_fn=setup_vllm_engine, setup_vllm_engine_fn=setup_vllm_engine,
setup_kv_event_publisher_fn=setup_kv_event_publisher, setup_kv_event_publisher_fn=setup_kv_event_publisher,
register_vllm_model_fn=register_vllm_model, register_vllm_model_fn=register_vllm_model,
setup_fpm_relay_fn=setup_fpm_relay,
setup_metrics_collection_fn=setup_metrics_collection,
) )
await factory.create( await factory.create(
runtime, runtime,
...@@ -159,23 +160,9 @@ async def worker() -> None: ...@@ -159,23 +160,9 @@ async def worker() -> None:
shutdown_endpoints, shutdown_endpoints,
snapshot_engine=snapshot_engine, snapshot_engine=snapshot_engine,
) )
logger.debug("multimodal worker completed") logger.debug("worker init completed")
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
await init_prefill(
runtime,
config,
shutdown_event,
snapshot_engine=snapshot_engine,
)
logger.debug("init_prefill completed")
else: else:
await init( raise ValueError("Unsupported worker configuration")
runtime,
config,
shutdown_event,
snapshot_engine=snapshot_engine,
)
logger.debug("init completed")
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
...@@ -625,378 +612,6 @@ async def register_vllm_model( ...@@ -625,378 +612,6 @@ async def register_vllm_model(
) )
async def init_prefill(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    snapshot_engine: Optional[
        tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
    ] = None,
) -> None:
    """Bring up a prefill worker and serve its endpoints.

    Creates the generate and ``clear_kv_blocks`` endpoints, obtains a vLLM
    engine (reusing ``snapshot_engine`` when given, otherwise calling
    ``setup_vllm_engine``), wraps it in a ``PrefillWorkerHandler``, attaches
    KV event publishers and forward-pass-metrics relays when any are
    returned, registers the ``sleep``/``wake_up`` engine routes, registers
    the model as ``ModelType.Prefill`` and then serves both endpoints
    concurrently. ``handler.cleanup()`` runs whether serving returns or
    raises.

    Args:
        runtime: Runtime used to create endpoints and register engine routes.
        config: Worker configuration.
        shutdown_event: Event forwarded to the handler.
        snapshot_engine: Optional pre-created engine tuple (checkpoint mode)
            laid out as ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component_path = f"{config.namespace}.{config.component}"
    generate_endpoint = runtime.endpoint(f"{component_path}.{config.endpoint}")
    clear_endpoint = runtime.endpoint(f"{component_path}.clear_kv_blocks")

    # Reuse the pre-created engine in checkpoint mode, otherwise build one.
    fpm_worker_id = str(generate_endpoint.connection_id())
    if snapshot_engine is None:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = setup_vllm_engine(config, fpm_worker_id=fpm_worker_id)
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            _component_gauges,
        ) = snapshot_engine
        # TODO: The scheduler in the child process still has worker_id=""
        # because the engine was forked before the runtime existed.
        # Propagating the new ID to the child requires shared memory or
        # a restart of the EngineCore process.
        vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id

    max_model_len = getattr(
        getattr(vllm_config, "model_config", None), "max_model_len", None
    )
    handler = PrefillWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # The KV event consolidator is considered enabled when
    # "consolidator_endpoints" is present in additional_config. The tuple is
    # (vllm_endpoint, bind_endpoint, connect_endpoint); clients subscribe to
    # the connect endpoint, so its port is the one passed on.
    consolidator_endpoints = vllm_config.additional_config.get(
        "consolidator_endpoints"
    )
    consolidator_enabled = bool(consolidator_endpoints)
    consolidator_port = (
        int(consolidator_endpoints[2].split(":")[-1])
        if consolidator_enabled
        else None
    )

    # KV event publishers for prefix caching (one per dp_rank). With the
    # consolidator enabled the publisher subscribes to its output instead.
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Forward pass metrics relay (child ZMQ -> event plane).
    # In checkpoint mode the engine was created before the runtime, so
    # ForwardPassMetrics.worker_id will be empty (relay still works).
    fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
    if fpm_relays:
        handler.fpm_relays = fpm_relays

    setup_metrics_collection(config, generate_endpoint, logger)

    # Expose sleep / wake_up as engine routes.
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # Replace the contents of the shared list in place so existing references
    # to it observe the new endpoints.
    shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

    # Prefill workers always register with ModelType.Prefill.
    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
    await register_vllm_model(
        model_input,
        ModelType.Prefill,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmPrefillHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for prefill worker")
        # In practice config.served_model_name is always set, but mypy needs
        # the "or" fallback here.
        served_name = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, served_name),
            (prometheus_names.labels.MODEL_NAME, served_name),
        ]
        await asyncio.gather(
            # For prefill, the engine shuts down only after all prefill
            # requests are finished because
            # (temp reason): re-routing prefill requests is not supported
            # (long-term reason): the prefill engine should pull from a global
            # queue so only a few in-flight requests exist and they finish
            # quickly.
            generate_endpoint.serve_endpoint(
                handler.generate,  # type: ignore
                graceful_shutdown=True,
                metrics_labels=model_metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,  # type: ignore
                metrics_labels=model_metrics_labels,
            ),
        )
        logger.debug("serve_endpoint completed for prefill worker")
    except Exception as exc:
        logger.error(f"Failed to serve endpoints: {exc}")
        raise
    finally:
        logger.debug("Cleaning up prefill worker")
        handler.cleanup()
async def init(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    snapshot_engine: Optional[
        tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
    ] = None,
) -> None:
    """Bring up a decode (or aggregated) worker and serve its endpoints.

    Creates the generate and ``clear_kv_blocks`` endpoints, plus the
    ``load_lora`` / ``unload_lora`` / ``list_loras`` endpoints when
    ``config.engine_args.enable_lora`` is set, and records all of them in
    ``shutdown_endpoints``. It then obtains a vLLM engine (reusing
    ``snapshot_engine`` when given, otherwise calling ``setup_vllm_engine``),
    wraps it in a ``DecodeWorkerHandler``, attaches KV event publishers and
    forward-pass-metrics relays when any are returned, registers the
    ``sleep``/``wake_up`` engine routes, registers the model with the type
    parsed from ``config.endpoint_types`` and serves every endpoint
    concurrently. ``handler.cleanup()`` runs whether serving returns or
    raises.

    Args:
        runtime: Runtime used to create endpoints and register engine routes.
        config: Worker configuration.
        shutdown_event: Event forwarded to the handler.
        snapshot_engine: Optional pre-created engine tuple (checkpoint mode)
            laid out as ``(engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges)``.

    Raises:
        Exception: Any error raised while serving is logged and re-raised.
    """
    component_path = f"{config.namespace}.{config.component}"
    generate_endpoint = runtime.endpoint(f"{component_path}.{config.endpoint}")
    clear_endpoint = runtime.endpoint(f"{component_path}.clear_kv_blocks")

    # Replace the contents of the shared list in place so existing references
    # to it observe the new endpoints.
    shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

    # LoRA management endpoints exist only when LoRA is enabled. Each endpoint
    # name doubles as the name of the handler method that serves it.
    lora_enabled = config.engine_args.enable_lora
    lora_endpoints = {}
    if lora_enabled:
        for lora_op in ("load_lora", "unload_lora", "list_loras"):
            lora_endpoints[lora_op] = runtime.endpoint(f"{component_path}.{lora_op}")
        shutdown_endpoints.extend(lora_endpoints.values())

    # Reuse the pre-created engine in checkpoint mode, otherwise build one.
    fpm_worker_id = str(generate_endpoint.connection_id())
    if snapshot_engine is None:
        # The factory starts without component_gauges; setup_vllm_engine()
        # will create the gauges after setup_multiprocess_prometheus() and
        # set them on the factory before vLLM calls create_stat_logger().
        factory = StatLoggerFactory(
            endpoint=generate_endpoint,
        )
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = setup_vllm_engine(config, factory, fpm_worker_id=fpm_worker_id)
    else:
        (
            engine_client,
            vllm_config,
            default_sampling_params,
            prometheus_temp_dir,
            component_gauges,
        ) = snapshot_engine
        vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
        # The factory is built after unpacking so component_gauges is
        # available to it.
        factory = StatLoggerFactory(
            endpoint=generate_endpoint,
            component_gauges=component_gauges,
        )

    # TODO Hack to get data, move this to registering in TBD
    factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
    factory.init_publish()

    max_model_len = getattr(
        getattr(vllm_config, "model_config", None), "max_model_len", None
    )
    handler = DecodeWorkerHandler(
        runtime,
        engine_client,
        default_sampling_params,
        max_model_len,
        enable_multimodal=config.enable_multimodal,
        generate_endpoint=generate_endpoint,
        config=config,
        use_vllm_tokenizer=config.use_vllm_tokenizer,
        shutdown_event=shutdown_event,
        enable_frontend_decoding=config.frontend_decoding,
    )
    handler.add_temp_dir(prometheus_temp_dir)

    # The KV event consolidator is considered enabled when
    # "consolidator_endpoints" is present in additional_config. The tuple is
    # (vllm_endpoint, bind_endpoint, connect_endpoint); clients subscribe to
    # the connect endpoint, so its port is the one passed on.
    consolidator_endpoints = vllm_config.additional_config.get(
        "consolidator_endpoints"
    )
    consolidator_enabled = bool(consolidator_endpoints)
    consolidator_port = (
        int(consolidator_endpoints[2].split(":")[-1])
        if consolidator_enabled
        else None
    )

    # KV event publisher for prefix caching. With the consolidator enabled
    # the publisher subscribes to its output instead.
    kv_publishers = setup_kv_event_publisher(
        config,
        generate_endpoint,
        vllm_config,
        consolidator_enabled=consolidator_enabled,
        consolidator_port=consolidator_port,
    )
    if kv_publishers:
        handler.kv_publishers = kv_publishers

    # Forward pass metrics relay (child ZMQ -> event plane).
    # In checkpoint mode the engine was created before the runtime, so
    # ForwardPassMetrics.worker_id will be empty (relay still works).
    fpm_relays = setup_fpm_relay(generate_endpoint, vllm_config)
    if fpm_relays:
        handler.fpm_relays = fpm_relays

    setup_metrics_collection(config, generate_endpoint, logger)

    # Expose sleep / wake_up as engine routes.
    runtime.register_engine_route("sleep", handler.sleep)
    runtime.register_engine_route("wake_up", handler.wake_up)
    logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")

    # The model type comes from the --endpoint-types flag.
    model_type = parse_endpoint_types(config.endpoint_types)
    logger.info(f"Registering model with endpoint types: {config.endpoint_types}")

    model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens

    # A custom chat template has no visible effect unless 'chat' is served.
    if config.custom_jinja_template and "chat" not in config.endpoint_types:
        logger.warning(
            "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
            "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
        )

    await register_vllm_model(
        model_input,
        model_type,
        generate_endpoint,
        config,
        engine_client,
        vllm_config,
    )

    health_check_payload = VllmHealthCheckPayload(
        engine_client, use_text_input=config.use_vllm_tokenizer
    ).to_dict()

    try:
        logger.debug("Starting serve_endpoint for decode worker")
        served_name = config.served_model_name or config.model
        model_metrics_labels = [
            (prometheus_names.labels.MODEL, served_name),
            (prometheus_names.labels.MODEL_NAME, served_name),
        ]

        serve_tasks = [
            # For decode, in-flight requests should be transferred to other
            # decode engines, because waiting for them to finish can take a
            # long time for long OSLs.
            generate_endpoint.serve_endpoint(
                handler.generate,  # type: ignore
                graceful_shutdown=True,
                metrics_labels=model_metrics_labels,
                health_check_payload=health_check_payload,
            ),
            clear_endpoint.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=model_metrics_labels,
            ),
        ]
        # Empty unless LoRA is enabled.
        for lora_op, lora_endpoint in lora_endpoints.items():
            serve_tasks.append(
                lora_endpoint.serve_endpoint(
                    getattr(handler, lora_op),
                    metrics_labels=model_metrics_labels,
                )
            )

        await asyncio.gather(*serve_tasks)
        logger.debug("serve_endpoint completed for decode worker")
    except Exception as exc:
        logger.error(f"Failed to serve endpoints: {exc}")
        raise
    finally:
        logger.debug("Cleaning up decode worker")
        # Cleanup background tasks
        handler.cleanup()
def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]: def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]:
"""Retrieve cache configuration information from [`AsyncLLM`] engine.""" """Retrieve cache configuration information from [`AsyncLLM`] engine."""
......
...@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, Mock ...@@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, Mock
import pytest import pytest
from dynamo.vllm.constants import DisaggregationMode
from dynamo.vllm.worker_factory import EngineSetupResult, WorkerFactory from dynamo.vllm.worker_factory import EngineSetupResult, WorkerFactory
pytestmark = [ pytestmark = [
...@@ -25,7 +26,8 @@ def _make_config(**overrides) -> Mock: ...@@ -25,7 +26,8 @@ def _make_config(**overrides) -> Mock:
"multimodal_worker": False, "multimodal_worker": False,
"multimodal_decode_worker": False, "multimodal_decode_worker": False,
"omni": False, "omni": False,
"is_prefill_worker": False, "route_to_encoder": False,
"disaggregation_mode": DisaggregationMode.AGGREGATED,
} }
defaults.update(overrides) defaults.update(overrides)
return Mock(**defaults) return Mock(**defaults)
...@@ -34,31 +36,55 @@ def _make_config(**overrides) -> Mock: ...@@ -34,31 +36,55 @@ def _make_config(**overrides) -> Mock:
class TestHandles: class TestHandles:
"""Test WorkerFactory.handles() config detection.""" """Test WorkerFactory.handles() config detection."""
def test_multimodal_encode_worker(self) -> None: # Legacy worker config
config = _make_config(multimodal_encode_worker=True) @pytest.mark.parametrize("route_to_encode", [True, False])
def test_multimodal_encode_worker(self, route_to_encode: bool) -> None:
# 'route_to_encoder' can be passed, the worker creation may ignore it.
config = _make_config(
multimodal_encode_worker=True, route_to_encoder=route_to_encode
)
assert WorkerFactory.handles(config) assert WorkerFactory.handles(config)
def test_multimodal_worker(self) -> None: @pytest.mark.parametrize("route_to_encode", [True, False])
config = _make_config(multimodal_worker=True) def test_multimodal_worker(self, route_to_encode: bool) -> None:
config = _make_config(multimodal_worker=True, route_to_encoder=route_to_encode)
assert WorkerFactory.handles(config) assert WorkerFactory.handles(config)
def test_multimodal_decode_worker(self) -> None: @pytest.mark.parametrize("route_to_encode", [True, False])
config = _make_config(multimodal_decode_worker=True) def test_multimodal_decode_worker(self, route_to_encode: bool) -> None:
config = _make_config(
multimodal_decode_worker=True, route_to_encoder=route_to_encode
)
assert WorkerFactory.handles(config) assert WorkerFactory.handles(config)
def test_no_multimodal_flags(self) -> None: # Tests for no standalone encode worker setting
config = _make_config() @pytest.mark.parametrize("route_to_encode", [True, False])
assert not WorkerFactory.handles(config) def test_no_multimodal_flags(self, route_to_encode: bool) -> None:
config = _make_config(route_to_encoder=route_to_encode)
assert WorkerFactory.handles(config)
def test_omni_not_handled(self) -> None: @pytest.mark.parametrize("route_to_encode", [True, False])
config = _make_config(omni=True) def test_prefill(self, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode,
)
# [gluo NOTE] due to current limitation, see 'WorkerFactory._validate_config()'.
if route_to_encode:
assert not WorkerFactory.handles(config) assert not WorkerFactory.handles(config)
else:
assert WorkerFactory.handles(config)
def test_prefill_only_not_handled(self) -> None: @pytest.mark.parametrize("route_to_encode", [True, False])
config = _make_config(is_prefill_worker=True) def test_decode(self, route_to_encode: bool) -> None:
assert not WorkerFactory.handles(config) config = _make_config(
disaggregation_mode=DisaggregationMode.DECODE,
route_to_encoder=route_to_encode,
)
assert WorkerFactory.handles(config)
@pytest.mark.asyncio
class TestCreate: class TestCreate:
"""Test WorkerFactory.create() routing.""" """Test WorkerFactory.create() routing."""
...@@ -68,41 +94,90 @@ class TestCreate: ...@@ -68,41 +94,90 @@ class TestCreate:
setup_vllm_engine_fn=Mock(), setup_vllm_engine_fn=Mock(),
setup_kv_event_publisher_fn=Mock(), setup_kv_event_publisher_fn=Mock(),
register_vllm_model_fn=AsyncMock(), register_vllm_model_fn=AsyncMock(),
setup_fpm_relay_fn=Mock(),
setup_metrics_collection_fn=Mock(),
) )
factory._create_multimodal_encode_worker = AsyncMock() # type: ignore[assignment] factory._create_multimodal_encode_worker = AsyncMock() # type: ignore[assignment]
factory._create_multimodal_worker = AsyncMock() # type: ignore[assignment] factory._create_multimodal_worker = AsyncMock() # type: ignore[assignment]
factory._create_prefill_worker = AsyncMock() # type: ignore[assignment]
factory._create_decode_worker = AsyncMock() # type: ignore[assignment]
return factory return factory
@pytest.mark.asyncio # Tests for non-legacy worker config, 'route_to_encode' is worker internal config
async def test_routes_to_multimodal_encode(self, factory: WorkerFactory) -> None: # so either case should hit creation function.
config = _make_config(multimodal_encode_worker=True) @pytest.mark.parametrize("route_to_encode", [True, False])
async def test_aggregated(
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(route_to_encoder=route_to_encode)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_decode_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_prefill(self, factory: WorkerFactory, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.PREFILL,
route_to_encoder=route_to_encode,
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_prefill_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_decode(self, factory: WorkerFactory, route_to_encode: bool) -> None:
config = _make_config(
disaggregation_mode=DisaggregationMode.DECODE,
route_to_encoder=route_to_encode,
)
shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, [])
factory._create_decode_worker.assert_called_once() # type: ignore[union-attr]
# Tests with legacy worker config.
@pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_to_multimodal_encode(
self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(
multimodal_encode_worker=True, route_to_encoder=route_to_encode
)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, []) await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_encode_worker.assert_called_once() # type: ignore[union-attr] factory._create_multimodal_encode_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio @pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_to_multimodal_worker(self, factory: WorkerFactory) -> None: async def test_routes_to_multimodal_worker(
config = _make_config(multimodal_worker=True) self, factory: WorkerFactory, route_to_encode: bool
) -> None:
config = _make_config(multimodal_worker=True, route_to_encoder=route_to_encode)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, []) await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr] factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio @pytest.mark.parametrize("route_to_encode", [True, False])
async def test_routes_multimodal_decode_worker( async def test_routes_multimodal_decode_worker(
self, factory: WorkerFactory self, factory: WorkerFactory, route_to_encode: bool
) -> None: ) -> None:
config = _make_config(multimodal_decode_worker=True) config = _make_config(
multimodal_decode_worker=True, route_to_encoder=route_to_encode
)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
await factory.create(Mock(), config, shutdown_event, []) await factory.create(Mock(), config, shutdown_event, [])
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr] factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio
async def test_passes_snapshot_engine(self, factory: WorkerFactory) -> None: async def test_passes_snapshot_engine(self, factory: WorkerFactory) -> None:
config = _make_config(multimodal_worker=True) config = _make_config(multimodal_worker=True)
runtime = Mock() runtime = Mock()
...@@ -131,9 +206,3 @@ class TestCreate: ...@@ -131,9 +206,3 @@ class TestCreate:
shutdown_endpoints, shutdown_endpoints,
snapshot_engine=snapshot_engine, snapshot_engine=snapshot_engine,
) )
@pytest.mark.asyncio
async def test_raises_when_no_multimodal_flag(self, factory: WorkerFactory) -> None:
config = _make_config()
with pytest.raises(ValueError, match="no multimodal worker type set"):
await factory.create(Mock(), config, asyncio.Event(), [])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment