handlers.py 33 KB
Newer Older
Alec's avatar
Alec committed
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
6
import os
7
import tempfile
Alec's avatar
Alec committed
8
from abc import ABC, abstractmethod
9
from contextlib import asynccontextmanager
10
from typing import Any, AsyncGenerator, Dict, Final
Alec's avatar
Alec committed
11
12

from vllm.inputs import TokensPrompt
13
from vllm.lora.request import LoRARequest
14
from vllm.outputs import RequestOutput
Alec's avatar
Alec committed
15
from vllm.sampling_params import SamplingParams
16
from vllm.v1.engine.exceptions import EngineDeadError
Alec's avatar
Alec committed
17

18
19
20
21
22
23
24
25
from dynamo.llm import (
    ModelInput,
    ModelType,
    ZmqKvEventPublisher,
    lora_name_to_id,
    register_llm,
    unregister_llm,
)
Alec's avatar
Alec committed
26
27
from dynamo.runtime.logging import configure_dynamo_logging

28
from .engine_monitor import VllmEngineMonitor
29
30
31
32
33
34
35
from .multimodal_utils.image_loader import ImageLoader

# Keys used when reading a request's "multi_modal_data" mapping.
# Media-type keys at the top level of the mapping:
IMAGE_URL_KEY: Final[str] = "image_url"
VIDEO_URL_KEY: Final[str] = "video_url"
# Variant tags on each media item: a URL to load vs. frontend-decoded data.
URL_VARIANT_KEY: Final[str] = "Url"
DECODED_VARIANT_KEY: Final[str] = "Decoded"
Alec's avatar
Alec committed
36

Alec's avatar
Alec committed
37
38
39
# Apply Dynamo's logging configuration at import time, then obtain this
# module's logger (used by every handler below).
configure_dynamo_logging()
logger = logging.getLogger(__name__)

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# LoRAManager singleton - initialized lazily when DYN_LORA_ENABLED is set.
# None = not yet initialized, False = disabled/failed, LoRAManager = initialized.
# Caching the negative result means a disabled or failed setup is decided once
# instead of re-reading the environment and re-logging on every call.
_lora_manager = None


def get_lora_manager():
    """Return the process-wide LoRAManager, creating it on first use.

    The manager is only created when the ``DYN_LORA_ENABLED`` environment
    variable is ``true``, ``1`` or ``yes`` (case-insensitive). The outcome of
    the first call is cached: a disabled or failed initialization is stored as
    ``False`` so later calls return ``None`` immediately without retrying.

    Returns:
        The ``LoRAManager`` instance, or ``None`` when URI-based LoRA loading
        is disabled or the manager could not be initialized.
    """
    global _lora_manager

    if _lora_manager is not None:
        # False is the cached "disabled/failed" marker; callers expect None.
        return None if _lora_manager is False else _lora_manager

    if os.environ.get("DYN_LORA_ENABLED", "").lower() not in ("true", "1", "yes"):
        _lora_manager = False
        return None

    try:
        # Imported lazily so the dependency is only needed when LoRA is enabled.
        from dynamo.common.lora import LoRAManager

        _lora_manager = LoRAManager()
    except Exception as e:
        logger.warning(
            f"Failed to initialize LoRAManager: {e}. URI-based LoRA loading will be disabled."
        )
        _lora_manager = False
        return None

    logger.info("LoRAManager initialized successfully")
    return _lora_manager

Alec's avatar
Alec committed
66

67
def build_sampling_params(
    request: Dict[str, Any],
    default_sampling_params: Dict[str, Any],
    model_max_len: int | None = None,
) -> SamplingParams:
    """
    Build SamplingParams from a PreprocessedRequest.

    Args:
        request: The PreprocessedRequest dict. Non-None entries of its
            'sampling_options' and 'stop_conditions' mappings that match a
            SamplingParams attribute are copied over ('stop' is skipped);
            'token_ids' is used to size the default max_tokens. Missing or
            null sections are treated as empty.
        default_sampling_params: Default sampling parameters to initialize with
        model_max_len: Model context length. When given and the request does
            not set 'max_tokens', max_tokens defaults to the remaining context
            (model_max_len - len(token_ids)), but never less than 1.

    Returns:
        SamplingParams configured from the request
    """
    sampling_params = SamplingParams(**default_sampling_params)
    # Only token ids are streamed back, so vLLM need not detokenize.
    sampling_params.detokenize = False

    # Read each section once; absent or null sections behave as empty.
    sampling_options = request.get("sampling_options") or {}
    stop_conditions = request.get("stop_conditions") or {}

    # Apply sampling_options
    for key, value in sampling_options.items():
        if value is not None and hasattr(sampling_params, key):
            setattr(sampling_params, key, value)

    # Apply stop_conditions
    for key, value in stop_conditions.items():
        # Do not add stop key to sampling params - dynamo handles stop conditions directly
        if key == "stop":
            continue
        if value is not None and hasattr(sampling_params, key):
            setattr(sampling_params, key, value)

    # If max_tokens wasn't provided (None or missing), compute a dynamic default
    if model_max_len is not None and stop_conditions.get("max_tokens") is None:
        input_length = len(request.get("token_ids") or [])
        # Ensure at least 1 token generation by default when possible
        sampling_params.max_tokens = max(1, model_max_len - input_length)

    return sampling_params


Alec's avatar
Alec committed
110
111
112
113
114
class BaseWorkerHandler(ABC):
    """
    Request handler for the generate and clear_kv_blocks endpoints.

    Besides the abstract ``generate`` implemented by subclasses, this base
    class serves the LoRA management endpoints (load_lora, unload_lora,
    list_loras) and provides shared helpers: per-request abort monitoring,
    multimodal input extraction and token streaming from the vLLM engine.
    """

    def __init__(
        self,
        runtime,
        component,
        engine,
        default_sampling_params,
        model_max_len: int | None = None,
        enable_multimodal: bool = False,
        generate_endpoint=None,
        config=None,
    ):
        """
        Args:
            runtime: Dynamo runtime; ``shutdown()`` is called on it when the
                vLLM engine dies during generation.
            component: Component object, stored as ``self.component``.
            engine: vLLM engine client used for generation, abort, LoRA and
                prefix-cache operations.
            default_sampling_params: Keyword arguments used to construct each
                request's SamplingParams.
            model_max_len: Model context length; used to derive a default
                ``max_tokens`` when a request does not set one.
            enable_multimodal: Whether requests carrying ``multi_modal_data``
                are accepted.
            generate_endpoint: Endpoint LoRA ModelDeploymentCards are
                registered against. Without it (and, for registration, without
                ``config``) LoRA adapters are only loaded into the engine and
                are not published to discovery.
            config: Worker config; ``config.model`` and
                ``config.engine_args.block_size`` are used when publishing
                LoRA adapters.
        """
        self.runtime = runtime
        self.component = component
        self.engine_client = engine
        self.default_sampling_params = default_sampling_params
        # Not assigned elsewhere in this class; presumably populated by the
        # worker setup code when KV events are enabled - TODO confirm.
        self.kv_publishers: list[ZmqKvEventPublisher] | None = None
        self.generate_endpoint = generate_endpoint
        self.config = config
        self.engine_monitor = VllmEngineMonitor(runtime, engine)
        self.image_loader = ImageLoader()
        # Temporary directories removed by cleanup().
        self.temp_dirs: list[tempfile.TemporaryDirectory] = []
        self.model_max_len = model_max_len
        self.enable_multimodal = enable_multimodal

        # LoRA tracking: adapter name -> engine LoRA id / local adapter path.
        self.lora_id_for_name: dict[str, int] = {}
        self.lora_name_to_path: dict[str, str] = {}

    @abstractmethod
    async def generate(self, request, context) -> AsyncGenerator[dict, None]:
        """Handle a generate request; implemented by the prefill/decode handlers."""
        raise NotImplementedError

    async def _monitor_abort(self, context, request_id, is_prefill):
        """Background task that monitors for context cancellation and aborts the request.

        Waits until the request context is killed or stopped, then asks the
        engine to abort ``request_id``. Cancellation of this task (what
        ``_abort_monitor`` does when the request finishes first) is swallowed;
        any other error is logged rather than raised.
        """
        try:
            await context.async_killed_or_stopped()
            # If we reach here, the context was stopped or killed
            await self.engine_client.abort(request_id)
            logger.debug(
                f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}"
            )
        except asyncio.CancelledError:
            # Task was cancelled, normal cleanup if not aborted
            pass
        except Exception as e:
            logger.error(f"Error in abort monitor for request {request_id}: {e}")

    @asynccontextmanager
    async def _abort_monitor(self, context, request_id, is_prefill=False):
        """Context manager that creates and automatically cleans up an abort monitoring task.

        Yields the background ``_monitor_abort`` task; on exit the task is
        cancelled (if still running) and awaited so it never outlives the
        request.
        """
        task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill))
        try:
            yield task
        finally:
            # Cancel the abort monitoring task when exiting the context
            if not task.done():
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

    async def clear_kv_blocks(self, request=None):
        """Reset the engine's prefix cache and yield a single status dict.

        ``request`` is unused.
        """
        try:
            await self.engine_client.reset_prefix_cache()
            yield {"status": "success", "message": "KV cache cleared"}
        except Exception as e:
            yield {"status": "error", "message": str(e)}

    def add_temp_dir(self, temp_dir: tempfile.TemporaryDirectory) -> None:
        """Add a temporary directory to be cleaned up later (None is ignored)."""
        if temp_dir is not None:
            self.temp_dirs.append(temp_dir)

    async def load_lora(self, request=None):
        """
        Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.

        Request format:
        {
            "lora_name": str,
            "source": {
                "uri": str  # e.g., "s3://bucket/path" or "file:///path"
            }
        }

        The adapter is downloaded through the LoRAManager (requires
        DYN_LORA_ENABLED), added to the engine under an id derived from
        ``lora_name``, tracked on this handler and, when ``generate_endpoint``
        and ``config`` are set, published as a ModelDeploymentCard. If
        publishing fails the adapter is removed from the engine again.

        Yields:
            A single dict with ``status`` ("success" or "error") and
            ``message``; on success also ``lora_name`` and ``lora_id``.
        """
        try:
            if request is None:
                yield {
                    "status": "error",
                    "message": "Request is required with 'lora_name' and 'source.uri'",
                }
                return

            lora_name = request.get("lora_name")
            if not lora_name:
                yield {
                    "status": "error",
                    "message": "'lora_name' is required in request",
                }
                return

            # Debug: Log the incoming request
            logger.debug(f"load_lora request keys: {list(request.keys())}")
            logger.debug(f"load_lora request: {request}")

            # Check for URI-based API format (source.uri)
            source = request.get("source")
            if not source or not isinstance(source, dict):
                yield {
                    "status": "error",
                    "message": "'source' object is required in request",
                }
                return

            lora_uri = source.get("uri")
            if not lora_uri:
                yield {
                    "status": "error",
                    "message": "'source.uri' is required in request",
                }
                return

            # Use LoRAManager to download from URI
            lora_manager = get_lora_manager()
            if lora_manager is None:
                yield {
                    "status": "error",
                    "message": "LoRAManager not initialized. Set DYN_LORA_ENABLED=true to enable URI-based LoRA loading.",
                }
                return

            logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}")
            download_result = await lora_manager.download_lora(lora_uri)

            if download_result["status"] != "success":
                yield {
                    "status": "error",
                    "message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}",
                }
                return

            lora_path = download_result["local_path"]
            logger.debug(f"LoRA downloaded to: {lora_path}")

            # Generate deterministic ID from lora_name before using it
            lora_id = lora_name_to_id(lora_name)

            # Add the LoRA to the engine
            await self.engine_client.add_lora(
                LoRARequest(
                    lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path
                )
            )

            # Track the LoRA
            self.lora_id_for_name[lora_name] = lora_id
            self.lora_name_to_path[lora_name] = lora_path
            logger.info(
                f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
            )

            # Publish LoRA as a ModelDeploymentCard with format:
            # v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
            # This allows the frontend to discover it and route correctly to the worker instance
            if self.generate_endpoint is not None and self.config is not None:
                logger.debug(
                    f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
                )
                try:
                    # Mark this as a LoRA in user_data
                    user_data = {
                        "lora_adapter": True,
                        "lora_id": lora_id,
                    }

                    await register_llm(
                        model_input=ModelInput.Tokens,
                        model_type=ModelType.Chat | ModelType.Completions,
                        endpoint=self.generate_endpoint,
                        model_path=self.config.model,
                        kv_cache_block_size=self.config.engine_args.block_size,
                        user_data=user_data,
                        lora_name=lora_name,
                        base_model_path=self.config.model,
                    )
                    logger.info(
                        f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
                    )
                    logger.debug(
                        f"Traceback for failed publish of LoRA '{lora_name}':",
                        exc_info=True,
                    )

                    # Rollback: remove the LoRA from the engine to maintain consistency
                    try:
                        logger.debug(
                            f"Rolling back: removing LoRA '{lora_name}' from engine"
                        )
                        await self.engine_client.remove_lora(lora_id)
                        # Remove from tracking dictionaries
                        self.lora_id_for_name.pop(lora_name, None)
                        self.lora_name_to_path.pop(lora_name, None)
                        logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
                    except Exception as rollback_error:
                        logger.error(
                            f"Failed to rollback LoRA {lora_name}: {rollback_error}"
                        )

                    # Return error status since registration failed
                    yield {
                        "status": "error",
                        "message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}",
                        "lora_name": lora_name,
                    }
                    return
            else:
                logger.debug(
                    f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}"
                )

            yield {
                "status": "success",
                "message": f"LoRA adapter '{lora_name}' loaded successfully",
                "lora_name": lora_name,
                "lora_id": lora_id,
            }
        except Exception as e:
            logger.error(f"Failed to load LoRA adapter: {e}")
            yield {"status": "error", "message": str(e)}

    async def unload_lora(self, request=None):
        """
        Unload a LoRA adapter dynamically from the vLLM's AsyncLLM engine.
        Expected request format:
        {
            "lora_name": str,
        }

        The adapter is removed from the engine and from this handler's
        tracking and, when ``generate_endpoint`` is set, its
        ModelDeploymentCard is unregistered. If unregistering fails the
        adapter is re-added to the engine.

        Yields:
            A single dict with ``status`` ("success" or "error") and
            ``message``; on success also ``lora_name`` and ``lora_id``.
        """
        try:
            if request is None:
                yield {
                    "status": "error",
                    "message": "Request is required with 'lora_name' field",
                }
                return
            lora_name = request.get("lora_name")
            if not lora_name:
                yield {
                    "status": "error",
                    "message": "'lora_name' is required in request",
                }
                return

            # Check if the LoRA exists
            if lora_name not in self.lora_id_for_name:
                yield {
                    "status": "error",
                    "message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
                }
                return

            logger.debug(f"Unloading LoRA adapter: {lora_name}")
            lora_id = self.lora_id_for_name[lora_name]
            # Remembered so a failed unregistration can re-add the adapter.
            lora_path = self.lora_name_to_path.get(lora_name)

            await self.engine_client.remove_lora(lora_id)

            # Remove from tracking dictionaries
            del self.lora_id_for_name[lora_name]
            self.lora_name_to_path.pop(lora_name, None)

            # Unregister the LoRA model from the model registry
            if self.generate_endpoint is not None:
                logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard")
                try:
                    await unregister_llm(
                        endpoint=self.generate_endpoint,
                        lora_name=lora_name,
                    )
                    logger.info(
                        f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
                    )
                    logger.debug(
                        f"Traceback for failed unregister of LoRA '{lora_name}':",
                        exc_info=True,
                    )

                    # Rollback: re-add the LoRA to the engine to maintain consistency
                    try:
                        logger.debug(
                            f"Rolling back: re-adding LoRA '{lora_name}' to engine"
                        )
                        await self.engine_client.add_lora(
                            LoRARequest(
                                lora_name=lora_name,
                                lora_int_id=lora_id,
                                lora_path=lora_path,
                            )
                        )
                        # Re-add to tracking dictionaries
                        self.lora_id_for_name[lora_name] = lora_id
                        if lora_path:
                            self.lora_name_to_path[lora_name] = lora_path
                        logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
                    except Exception as rollback_error:
                        logger.error(
                            f"Failed to rollback LoRA {lora_name}: {rollback_error}"
                        )

                    # Return error status since unregistration failed
                    yield {
                        "status": "error",
                        "message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}",
                        "lora_name": lora_name,
                    }
                    return
            else:
                logger.debug(
                    f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}"
                )

            logger.info(
                f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}"
            )
            yield {
                "status": "success",
                "message": f"LoRA adapter '{lora_name}' unloaded successfully",
                "lora_name": lora_name,
                "lora_id": lora_id,
            }
        except Exception as e:
            logger.error(f"Failed to unload LoRA adapter: {e}")
            yield {"status": "error", "message": str(e)}

    async def list_loras(self, request=None):
        """
        List all loaded LoRA adapters.
        Returns a dictionary of lora_name -> lora_id mappings.

        Yields a single dict with ``status``, ``loras`` (a snapshot copy) and
        ``count``; ``request`` is unused.
        """
        try:
            loras = dict(self.lora_id_for_name)
            yield {
                "status": "success",
                "loras": loras,
                "count": len(loras),
            }
        except Exception as e:
            logger.error(f"Failed to list LoRA adapters: {e}")
            yield {"status": "error", "message": str(e)}

    def cleanup(self):
        """Clean up resources including temporary directories.

        A failure on one directory is logged and the remaining directories
        are still processed.
        """
        for temp_dir in self.temp_dirs:
            try:
                temp_dir.cleanup()
            except Exception as e:
                logger.warning(f"Failed to clean up temp directory: {e}")

    async def _extract_multimodal_data(
        self, request: Dict[str, Any]
    ) -> Dict[str, Any] | None:
        """
        Extract and decode multimodal data from PreprocessedRequest.

        Only ``image_url`` items with a ``Url`` variant are loaded; ``Decoded``
        items and ``video_url`` entries are not supported yet and are only
        logged. A null ``image_url`` list is treated as empty.

        Returns:
            A dict for vLLM's ``multi_modal_data`` (an ``"image"`` entry holding
            a single image or a list of images), or None when the request
            carries no usable multimodal data.

        Raises:
            ValueError: If multimodal data is present but multimodal
                processing is disabled.
            Exception: Any error from loading an image is logged and re-raised.
        """
        mm_map = request.get("multi_modal_data")
        if mm_map is None:
            return None

        # Security check: reject multimodal data if not explicitly enabled
        if not self.enable_multimodal:
            raise ValueError(
                "Received multimodal data but multimodal processing is not enabled. "
                "Use --enable-multimodal flag to enable multimodal processing."
            )

        vllm_mm_data = {}

        # Process image_url entries
        images = []
        for item in mm_map.get(IMAGE_URL_KEY) or []:
            if isinstance(item, dict) and URL_VARIANT_KEY in item:
                url = item[URL_VARIANT_KEY]
                try:
                    # ImageLoader supports both data: and http(s): URLs with caching
                    image = await self.image_loader.load_image(url)
                    images.append(image)
                    logger.debug(f"Loaded image from URL: {url[:80]}...")
                except Exception:
                    logger.exception(f"Failed to load image from {url[:80]}...")
                    raise
            elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
                # Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
                # Will contain NIXL metadata for direct memory access
                # TODO: Implement NIXL read when PRs merge
                logger.warning(
                    "Decoded multimodal data not yet supported in standard worker"
                )

        if images:
            # vLLM expects single image or list
            vllm_mm_data["image"] = images[0] if len(images) == 1 else images
            logger.debug(f"Extracted {len(images)} image(s) for multimodal processing")

        # Handle video_url entries (future expansion)
        if VIDEO_URL_KEY in mm_map:
            logger.warning("Video multimodal data not yet supported in standard worker")

        return vllm_mm_data if vllm_mm_data else None

    @staticmethod
    def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]:
        """Build a token-usage dict from a vLLM RequestOutput.

        Only the first completion (``outputs[0]``) is counted. Prompt-based
        fields are None when no prompt token ids are available, and
        ``prompt_tokens_details`` is None when no cached tokens are reported.
        """
        return {
            "prompt_tokens": (
                len(request_output.prompt_token_ids)
                if request_output.prompt_token_ids
                else None
            ),
            "completion_tokens": len(request_output.outputs[0].token_ids),
            "total_tokens": (
                len(request_output.prompt_token_ids)
                + len(request_output.outputs[0].token_ids)
                if request_output.prompt_token_ids
                else None
            ),
            "prompt_tokens_details": (
                {"cached_tokens": request_output.num_cached_tokens}
                if request_output.num_cached_tokens
                else None
            ),
        }

    async def generate_tokens(
        self,
        prompt,
        sampling_params,
        request_id,
        data_parallel_rank=None,
        lora_request=None,
    ):
        """Stream incremental token ids for one request from the vLLM engine.

        Each yielded dict holds the ``token_ids`` produced since the previous
        chunk. A chunk whose output has a ``finish_reason`` also carries it
        together with ``completion_usage``; ``stop_reason`` is added whenever
        vLLM reports one. If the engine returns a result without outputs, a
        single ``{"finish_reason": "error", "token_ids": []}`` chunk is yielded
        and streaming stops.

        Raises:
            GeneratorExit: When generation is cancelled, so the frontend can
                migrate the request.

        On ``EngineDeadError`` the Dynamo runtime is shut down and the process
        exits.
        """
        try:
            # Log LoRA usage for this generation (debug level to avoid log spam)
            if lora_request:
                logger.debug(
                    f"Starting token generation for request {request_id} with LoRA: "
                    f"{lora_request.lora_name} (ID: {lora_request.lora_int_id})"
                )
            else:
                logger.debug(
                    f"Starting token generation for request {request_id} (no LoRA)"
                )

            gen = self.engine_client.generate(
                prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                data_parallel_rank=data_parallel_rank,
            )

            # output.token_ids is treated as cumulative; track how many were
            # already sent so each chunk carries only the new ones.
            num_output_tokens_so_far = 0
            try:
                async for res in gen:
                    # res is vllm's RequestOutput
                    if not res.outputs:
                        if lora_request:
                            logger.debug(
                                f"Request {request_id} with LoRA {lora_request.lora_name} "
                                "returned no outputs"
                            )
                        yield {"finish_reason": "error", "token_ids": []}
                        break

                    output = res.outputs[0]
                    next_total_toks = len(output.token_ids)
                    out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
                    if output.finish_reason:
                        out["finish_reason"] = output.finish_reason
                        out[
                            "completion_usage"
                        ] = BaseWorkerHandler._build_completion_usage(
                            request_output=res
                        )
                        # Log completion with LoRA info (debug level to avoid log spam)
                        if lora_request:
                            logger.debug(
                                f"Completed token generation for request {request_id} with LoRA "
                                f"{lora_request.lora_name}: {next_total_toks} output tokens, "
                                f"finish_reason={output.finish_reason}"
                            )
                        else:
                            logger.debug(
                                f"Completed token generation for request {request_id}: "
                                f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
                            )
                    if output.stop_reason:
                        out["stop_reason"] = output.stop_reason
                    yield out
                    num_output_tokens_so_far = next_total_toks
            except asyncio.CancelledError:
                # Convert cancellation (e.g. the engine shutting down) into
                # GeneratorExit so that the frontend can migrate the request
                raise GeneratorExit(
                    "Decode engine was shut down during token generation"
                ) from None

        except EngineDeadError as e:
            logger.error(f"vLLM EngineDeadError: {e}")
            logger.warning("Initiating Dynamo Runtime shutdown.")
            self.runtime.shutdown()
            os._exit(1)
Alec's avatar
Alec committed
638
639
640
641


class DecodeWorkerHandler(BaseWorkerHandler):
    """Handler that streams generated tokens for a request.

    When the request carries a dict ``prefill_result`` (from a separate
    prefill step), its ``kv_transfer_params`` are passed to vLLM through
    ``sampling_params.extra_args`` and its ``prompt_tokens_details`` replace
    the locally computed ones in the final usage report.
    """

    def __init__(
        self,
        runtime,
        component,
        engine,
        default_sampling_params,
        model_max_len: int | None = None,
        enable_multimodal: bool = False,
        generate_endpoint=None,
        config=None,
    ):
        """All arguments are forwarded unchanged to BaseWorkerHandler."""
        super().__init__(
            runtime,
            component,
            engine,
            default_sampling_params,
            model_max_len,
            enable_multimodal,
            generate_endpoint,
            config,
        )

    async def generate(self, request, context):
        """Stream generated tokens for a decode request.

        Args:
            request: PreprocessedRequest dict; reads ``token_ids`` and the
                sampling/stop sections, and optionally ``multi_modal_data``,
                ``prefill_result``, ``model`` (matched against loaded LoRA
                names) and ``dp_rank``.
            context: Request context; ``context.id()`` is used as the request
                id and its cancellation aborts the engine request.

        Yields:
            Token chunk dicts produced by ``generate_tokens``.
        """
        # Use context ID for request tracking and correlation
        request_id = context.id()
        logger.debug(f"Decode Request ID: {request_id}")

        # Extract and decode multimodal data if present
        multi_modal_data = await self._extract_multimodal_data(request)

        prompt = TokensPrompt(
            prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data
        )

        # Build sampling params from request
        sampling_params = build_sampling_params(
            request, self.default_sampling_params, self.model_max_len
        )

        # Only a dict prefill_result is usable; anything else is treated as absent.
        prefill_result = request.get("prefill_result")
        if not isinstance(prefill_result, dict):
            prefill_result = None

        kv_params = None
        prefill_prompt_tokens_details = None
        if prefill_result is not None:
            # disaggregated_params may be missing or explicitly null.
            disaggregated_params = prefill_result.get("disaggregated_params") or {}
            kv_params = disaggregated_params.get("kv_transfer_params")
            prefill_prompt_tokens_details = prefill_result.get("prompt_tokens_details")

        if kv_params is not None:
            if sampling_params.extra_args is None:
                sampling_params.extra_args = {}
            sampling_params.extra_args["kv_transfer_params"] = kv_params
            logger.debug(
                f"Using disaggregated params from prefill for request {request_id}"
            )

        # Extract LoRA request if present
        # Check if model name matches a loaded LoRA adapter
        lora_request = None
        model_name = request.get("model")

        if model_name and model_name in self.lora_id_for_name:
            lora_id = self.lora_id_for_name[model_name]
            lora_request = LoRARequest(
                lora_name=model_name,
                lora_int_id=lora_id,
                lora_path=self.lora_name_to_path[model_name],
            )
            # Per-request log kept at debug level to avoid log spam.
            logger.debug(
                f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id})"
            )
        else:
            logger.debug(
                f"Decode request {request_id} has no LoRA specified (model: {model_name})"
            )

        dp_rank = request.get("dp_rank", None)

        async with self._abort_monitor(context, request_id):
            try:
                async for tok in self.generate_tokens(
                    prompt,
                    sampling_params,
                    request_id,
                    data_parallel_rank=dp_rank,
                    lora_request=lora_request,
                ):
                    # Report the prompt-cache details observed during prefill.
                    if prefill_result is not None and "completion_usage" in tok:
                        tok["completion_usage"][
                            "prompt_tokens_details"
                        ] = prefill_prompt_tokens_details
                    yield tok
            except EngineDeadError as e:
                logger.error(f"vLLM EngineDeadError: {e}")
                logger.warning("Initiating Dynamo Runtime shutdown.")
                self.runtime.shutdown()
                os._exit(1)
Alec's avatar
Alec committed
740
741
742


class PrefillWorkerHandler(BaseWorkerHandler):
    """Worker handler for the prefill side of disaggregated serving.

    Every request is run through the engine with ``max_tokens=1`` and
    ``kv_transfer_params`` marking it for remote decode. The engine's
    ``kv_transfer_params`` from each output are streamed back to the caller
    inside ``disaggregated_params``.
    """

    def __init__(
        self,
        runtime,
        component,
        engine,
        default_sampling_params,
        model_max_len: int | None = None,
        enable_multimodal: bool = False,
        generate_endpoint=None,
        config=None,
    ):
        """Forward all construction arguments unchanged to ``BaseWorkerHandler``."""
        super().__init__(
            runtime,
            component,
            engine,
            default_sampling_params,
            model_max_len,
            enable_multimodal,
            generate_endpoint,
            config,
        )

    def _make_prefill_sampling_params(self, request):
        """Build sampling params for a prefill-only run of *request*.

        Starts from ``build_sampling_params`` (request values layered over the
        handler defaults), then:

        - marks the request for remote decode via ``kv_transfer_params``;
        - caps generation at exactly one token.
        """
        sampling_params = build_sampling_params(
            request, self.default_sampling_params, self.model_max_len
        )

        # Configure for prefill-only mode with remote decode. The dict is
        # created fresh here, so every other transfer key simply takes its
        # default value.
        if sampling_params.extra_args is None:
            sampling_params.extra_args = {}
        sampling_params.extra_args["kv_transfer_params"] = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }

        # Override for prefill: only generate 1 token.
        sampling_params.max_tokens = 1
        sampling_params.min_tokens = 1
        return sampling_params

    def _make_prefill_lora_request(self, request, request_id):
        """Return a ``LoRARequest`` when ``request["model"]`` names a loaded adapter.

        Adapters are looked up in ``self.lora_id_for_name`` and
        ``self.lora_name_to_path``. Returns ``None`` when the request carries
        no model name or the name is not a loaded adapter.
        """
        model_name = request.get("model")
        if not model_name or model_name not in self.lora_id_for_name:
            logger.debug(
                "Prefill request %s has no LoRA specified (model: %s)",
                request_id,
                model_name,
            )
            return None

        lora_id = self.lora_id_for_name[model_name]
        lora_path = self.lora_name_to_path[model_name]
        logger.info(
            "Prefill request %s will use LoRA adapter: %s (ID: %s), path: %s",
            request_id,
            model_name,
            lora_id,
            lora_path,
        )
        return LoRARequest(
            lora_name=model_name,
            lora_int_id=lora_id,
            lora_path=lora_path,
        )

    async def generate(self, request, context):
        """Run prefill for *request* and stream the engine outputs.

        Args:
            request: Request dict. Reads ``token_ids`` and, optionally,
                ``model`` (LoRA adapter name) and ``dp_rank``. It is also
                passed to ``_extract_multimodal_data`` and
                ``build_sampling_params``.
            context: Request context. ``context.id()`` is used as the engine
                request id so the prefill can be correlated with the decode
                phase.

        Yields:
            Dicts with the following keys:

            - ``token_ids``: the generated token ids;
            - ``disaggregated_params``: ``{"kv_transfer_params": ...}``, or
              ``None`` when the engine reported none;
            - ``completion_usage``.

        Raises:
            GeneratorExit: if the task is cancelled during generation, because
                prefill requests cannot be migrated.
        """
        # Use context ID for request tracking and correlation with decode phase.
        request_id = context.id()
        logger.debug("Prefill Request ID: %s", request_id)

        # Extract and decode multimodal data if present.
        multi_modal_data = await self._extract_multimodal_data(request)
        prompt = TokensPrompt(
            prompt_token_ids=request["token_ids"], multi_modal_data=multi_modal_data
        )

        sampling_params = self._make_prefill_sampling_params(request)
        lora_request = self._make_prefill_lora_request(request, request_id)
        dp_rank = request.get("dp_rank", None)

        async with self._abort_monitor(context, request_id, is_prefill=True):
            try:
                # The engine output is consumed as an async iterator. With
                # vLLM's AsyncLLM, generate() is an async generator, so
                # EngineDeadError surfaces while iterating rather than at call
                # time. The call and the loop therefore share one handler.
                async for res in self.engine_client.generate(
                    prompt,
                    sampling_params,
                    request_id,
                    data_parallel_rank=dp_rank,
                    lora_request=lora_request,
                ):
                    logger.debug("kv transfer params: %s", res.kv_transfer_params)

                    output_token_ids = res.outputs[0].token_ids if res.outputs else []

                    output: Dict[str, Any] = {
                        "token_ids": list(output_token_ids),
                        "disaggregated_params": (
                            {"kv_transfer_params": res.kv_transfer_params}
                            if res.kv_transfer_params
                            else None
                        ),
                        "completion_usage": BaseWorkerHandler._build_completion_usage(
                            request_output=res
                        ),
                    }

                    # Log prefill completion with LoRA info.
                    if lora_request:
                        logger.info(
                            "Prefill completed for request %s with LoRA %s: "
                            "generated %s token(s), has_kv_params=%s",
                            request_id,
                            lora_request.lora_name,
                            len(output_token_ids),
                            res.kv_transfer_params is not None,
                        )

                    yield output
            except EngineDeadError as e:
                logger.error(f"vLLM EngineDeadError: {e}")
                logger.warning("Initiating Dynamo Runtime shutdown.")
                self.runtime.shutdown()
                os._exit(1)
            except asyncio.CancelledError:
                # raise the error because we cannot migrate prefill requests
                raise GeneratorExit(
                    "Prefill engine was shut down during token generation"
                ) from None