# worker_factory.py
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Worker initialization factory for vLLM workers."""

import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Optional

from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime

from .args import Config
from .multimodal_handlers import (
    EncodeWorkerHandler,
    MultimodalDecodeWorkerHandler,
    MultimodalPDWorkerHandler,
)

logger = logging.getLogger(__name__)

# Positional result of engine setup, unpacked by WorkerFactory as:
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges)
EngineSetupResult = tuple[Any, Any, Any, Any, Any]

# Signatures of the callables injected into WorkerFactory.
# - SetupVllmEngineFn is called synchronously with the worker config and
#   returns an EngineSetupResult.
# - SetupKvEventPublisherFn is called with (config, generate_endpoint,
#   vllm_config) and may return None when no publisher is created.
# - RegisterVllmModelFn is awaited with (model_input, model_type,
#   generate_endpoint, config, engine_client, vllm_config).
SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]]
RegisterVllmModelFn = Callable[..., Awaitable[None]]


class WorkerFactory:
    """Factory for creating and initializing multimodal vLLM workers."""

    def __init__(
        self,
        setup_vllm_engine_fn: SetupVllmEngineFn,
        setup_kv_event_publisher_fn: SetupKvEventPublisherFn,
        register_vllm_model_fn: RegisterVllmModelFn,
    ):
        """Bind the injected setup and registration callables to this factory.

        Args:
            setup_vllm_engine_fn: Stored as ``self.setup_vllm_engine``; called
                with the worker config when no pre-created engine is supplied.
            setup_kv_event_publisher_fn: Stored as
                ``self.setup_kv_event_publisher``.
            register_vllm_model_fn: Stored as ``self.register_vllm_model``;
                awaited before the worker starts serving.
        """
        self.register_vllm_model = register_vllm_model_fn
        self.setup_kv_event_publisher = setup_kv_event_publisher_fn
        self.setup_vllm_engine = setup_vllm_engine_fn

    @staticmethod
    def handles(config: Config) -> bool:
        """Report whether *config* selects any multimodal worker role.

        Returns:
            True when at least one of the encode, PD, or decode multimodal
            worker flags on ``config`` is truthy; False otherwise.
        """
        worker_flags = (
            "multimodal_encode_worker",
            "multimodal_worker",
            "multimodal_decode_worker",
        )
        # Generator keeps the original left-to-right, short-circuit evaluation.
        return any(getattr(config, flag) for flag in worker_flags)

    async def create(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        shutdown_endpoints: list,
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """Create the appropriate multimodal worker based on config flags.

        The encode-worker flag is checked first: when
        ``config.multimodal_encode_worker`` is set, the standalone encode
        worker is started and ``pre_created_engine`` is not forwarded.
        Otherwise, when ``config.multimodal_worker`` or
        ``config.multimodal_decode_worker`` is set, the PD/decode worker is
        started. The coroutine completes when the selected initializer returns.

        Args:
            runtime: Forwarded to the selected worker initializer.
            config: Worker configuration whose flags select the worker type.
            shutdown_event: Forwarded to the selected worker initializer.
            shutdown_endpoints: Caller-owned list that the selected initializer
                mutates in place with the endpoints it serves.
            pre_created_engine: Optional engine setup result; only used by the
                PD/decode worker path.

        Raises:
            ValueError: If none of the multimodal worker flags is set.
        """
        if config.multimodal_encode_worker:
            await self._create_multimodal_encode_worker(
                runtime, config, shutdown_event, shutdown_endpoints
            )
        elif config.multimodal_worker or config.multimodal_decode_worker:
            await self._create_multimodal_worker(
                runtime,
                config,
                shutdown_event,
                shutdown_endpoints,
                pre_created_engine=pre_created_engine,
            )
        else:
            raise ValueError(
                "WorkerFactory.create() called but no multimodal worker type set in config"
            )

    async def _create_multimodal_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        shutdown_endpoints: list,  # mutated in place
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """
        Initialize multimodal worker component.

        Supports:
        - --multimodal-worker: PD worker that may receive embeddings from encoder
        - --multimodal-decode-worker: Decode-only worker

        Modes:
        - Aggregated (P+D): Prefill and decode on same worker
        - Disaggregated (P→D): Prefill forwards to separate decode worker

        Args:
            runtime: Used to resolve this worker's endpoints and, when
                configured, clients for the encoder and decoder endpoints.
            config: Worker configuration (namespace, component, endpoint,
                engine args, and the routing/worker-type flags read below).
            shutdown_event: Forwarded to the handler constructor.
            shutdown_endpoints: Caller-owned list; its contents are replaced
                with the generate and clear_kv_blocks endpoints, plus the
                three LoRA endpoints when LoRA is enabled.
            pre_created_engine: Engine setup result to reuse; when None the
                injected ``setup_vllm_engine`` callable is invoked instead.

        Raises:
            Exception: Any error raised while serving the endpoints is logged
                and re-raised; ``handler.cleanup()`` runs in either case.
        """
        component_path = f"{config.namespace}.{config.component}"
        generate_endpoint = runtime.endpoint(f"{component_path}.{config.endpoint}")
        clear_endpoint = runtime.endpoint(f"{component_path}.clear_kv_blocks")
        # Slice-assign so the caller's list object is updated, not rebound.
        shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]

        lora_enabled = config.engine_args.enable_lora
        lora_endpoints: list = []
        if lora_enabled:
            # Order matters: paired positionally with the handler methods
            # (load, unload, list) when the serve tasks are built below.
            lora_endpoints = [
                runtime.endpoint(f"{component_path}.load_lora"),
                runtime.endpoint(f"{component_path}.unload_lora"),
                runtime.endpoint(f"{component_path}.list_loras"),
            ]
            shutdown_endpoints.extend(lora_endpoints)

        # Use pre-created engine if provided (checkpoint mode), otherwise create new
        engine_setup = (
            pre_created_engine
            if pre_created_engine is not None
            else self.setup_vllm_engine(config)
        )
        # Default sampling params and component gauges are not used here.
        engine_client, vllm_config, _, prometheus_temp_dir, _ = engine_setup

        # Set up encode worker client when routing to encoder is enabled
        encode_worker_client = None
        if config.route_to_encoder:
            encode_worker_client = await runtime.endpoint(
                f"{config.namespace}.encoder.generate"
            ).client()
            logger.info("Waiting for Encoder Worker Instances ...")
            await encode_worker_client.wait_for_instances()
            logger.info("Connected to encoder workers")

        # Set up decode worker client for disaggregated mode
        decode_worker_client = None
        if config.is_prefill_worker:
            decode_worker_client = await runtime.endpoint(
                f"{config.namespace}.decoder.generate"
            ).client()
            await decode_worker_client.wait_for_instances()
            logger.info("Connected to decode worker for disaggregated mode")

        # Choose handler based on worker type
        if config.multimodal_decode_worker:
            handler = MultimodalDecodeWorkerHandler(
                runtime,
                engine_client,
                config,
                shutdown_event,
                generate_endpoint=generate_endpoint,
            )
        else:
            handler = MultimodalPDWorkerHandler(
                runtime,
                engine_client,
                config,
                encode_worker_client,
                decode_worker_client,
                shutdown_event,
                generate_endpoint=generate_endpoint,
            )
        handler.add_temp_dir(prometheus_temp_dir)

        # NOTE(review): a failure in async_init, publisher setup, or model
        # registration happens before the try/finally below, so
        # handler.cleanup() is not called in that case — confirm this is intended.
        await handler.async_init(runtime)

        # Set up KV event publisher for prefix caching if enabled
        kv_publisher = self.setup_kv_event_publisher(
            config, generate_endpoint, vllm_config
        )
        if kv_publisher:
            handler.kv_publisher = kv_publisher

        # Register model with the frontend so it can route requests
        model_type = parse_endpoint_types(config.endpoint_types)
        model_input = (
            ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
        )
        await self.register_vllm_model(
            model_input,
            model_type,
            generate_endpoint,
            config,
            engine_client,
            vllm_config,
        )

        metrics_labels = [("model", config.served_model_name or config.model)]
        try:
            serve_tasks = [
                generate_endpoint.serve_endpoint(
                    handler.generate,
                    metrics_labels=metrics_labels,
                ),
                clear_endpoint.serve_endpoint(
                    handler.clear_kv_blocks,
                    metrics_labels=metrics_labels,
                ),
            ]

            if lora_enabled:
                lora_handlers = (
                    handler.load_lora,
                    handler.unload_lora,
                    handler.list_loras,
                )
                serve_tasks.extend(
                    endpoint.serve_endpoint(lora_fn, metrics_labels=metrics_labels)
                    for endpoint, lora_fn in zip(lora_endpoints, lora_handlers)
                )

            await asyncio.gather(*serve_tasks)
        except Exception as e:
            logger.error("Failed to serve endpoints: %s", e)
            raise
        finally:
            handler.cleanup()

    async def _create_multimodal_encode_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        shutdown_endpoints: list,  # mutated in place
    ) -> None:
        """Initialize standalone multimodal encode worker.

        Args:
            runtime: Used to resolve the generate endpoint and passed to the
                handler's ``async_init``.
            config: Worker configuration (namespace, component, endpoint,
                engine args, model).
            shutdown_event: Currently unused by this method.
            shutdown_endpoints: Caller-owned list; its contents are replaced
                with the single generate endpoint served here.

        Raises:
            Exception: Any error raised while serving the endpoint is logged
                and re-raised; ``handler.cleanup()`` runs in either case.
        """
        generate_endpoint = runtime.endpoint(
            f"{config.namespace}.{config.component}.{config.endpoint}"
        )
        # Slice-assign so the caller's list object is updated, not rebound.
        shutdown_endpoints[:] = [generate_endpoint]

        handler = EncodeWorkerHandler(config.engine_args)
        await handler.async_init(runtime)
        logger.info("Starting to serve the encode worker endpoint...")

        # NOTE(review): the PD/decode worker labels metrics with
        # `config.served_model_name or config.model`; this path uses
        # `config.model` only — confirm the difference is intentional.
        metrics_labels = [("model", config.model)]
        try:
            await asyncio.gather(
                generate_endpoint.serve_endpoint(
                    handler.generate, metrics_labels=metrics_labels
                ),
            )
        except Exception as e:
            logger.error("Failed to serve encode worker endpoint: %s", e)
            raise
        finally:
            handler.cleanup()