handlers.py 9.8 KB
Newer Older
Alec's avatar
Alec committed
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
6
import os
Alec's avatar
Alec committed
7
from abc import ABC, abstractmethod
8
from contextlib import asynccontextmanager
9
from typing import Any, AsyncGenerator, Dict
Alec's avatar
Alec committed
10
11
12

from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
13
from vllm.v1.engine.exceptions import EngineDeadError
Alec's avatar
Alec committed
14

15
from dynamo.llm import ZmqKvEventPublisher
Alec's avatar
Alec committed
16
17
from dynamo.runtime.logging import configure_dynamo_logging

18
from .engine_monitor import VllmEngineMonitor
Alec's avatar
Alec committed
19

Alec's avatar
Alec committed
20
21
22
23
configure_dynamo_logging()
logger = logging.getLogger(__name__)


24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def build_sampling_params(
    request: Dict[str, Any], default_sampling_params: Dict[str, Any]
) -> SamplingParams:
    """
    Create the SamplingParams for one PreprocessedRequest.

    The defaults are used to construct the object, detokenization is turned
    off, and then every non-None entry of the request's 'sampling_options'
    and 'stop_conditions' that names an existing SamplingParams attribute is
    copied onto it.

    Args:
        request: The PreprocessedRequest dict; must contain the
            'sampling_options' and 'stop_conditions' mappings
        default_sampling_params: Keyword arguments used to construct the
            initial SamplingParams

    Returns:
        SamplingParams with the request's overrides applied
    """
    params = SamplingParams(**default_sampling_params)
    params.detokenize = False

    # Overlay sampling_options; unset values and unknown names are ignored
    for name, override in request["sampling_options"].items():
        if override is None or not hasattr(params, name):
            continue
        setattr(params, name, override)

    # Overlay stop_conditions. The "stop" key is never copied over because
    # dynamo handles stop conditions directly.
    for name, override in request["stop_conditions"].items():
        if override is None or not hasattr(params, name) or name == "stop":
            continue
        setattr(params, name, override)

    return params


Alec's avatar
Alec committed
56
57
58
59
60
class BaseWorkerHandler(ABC):
    """
    Shared base for the worker request handlers.

    Implements the clear_kv_blocks endpoint, the abort monitoring used while a
    request is in flight, and the token streaming loop. Subclasses provide the
    generate endpoint.
    """

    def __init__(self, runtime, component, engine, default_sampling_params):
        self.runtime = runtime
        self.component = component
        self.engine_client = engine
        self.default_sampling_params = default_sampling_params
        # Starts out unset; nothing in this class assigns it afterwards
        self.kv_publishers: list[ZmqKvEventPublisher] | None = None
        self.engine_monitor = VllmEngineMonitor(runtime, engine)

    @abstractmethod
    async def generate(self, request, context) -> AsyncGenerator[dict, None]:
        """Handle one request; must be implemented by subclasses."""
        raise NotImplementedError

    async def _monitor_abort(self, context, request_id, is_prefill):
        """Wait for the context to be stopped or killed, then abort the engine request."""
        try:
            await context.async_killed_or_stopped()
            # Getting past the await means the context was stopped or killed
            await self.engine_client.abort(request_id)
            phase = "Prefill " if is_prefill else ""
            logger.debug(f"Aborted {phase}Request ID: {request_id}")
        except asyncio.CancelledError:
            # The monitor itself was cancelled: the request was not aborted
            return
        except Exception as exc:
            logger.error(f"Error in abort monitor for request {request_id}: {exc}")

    @asynccontextmanager
    async def _abort_monitor(self, context, request_id, is_prefill=False):
        """Run the abort monitor as a background task for the duration of the block."""
        monitor_coro = self._monitor_abort(context, request_id, is_prefill)
        watcher = asyncio.create_task(monitor_coro)
        try:
            yield watcher
        finally:
            # On exit, stop a still-running monitor and wait for it to finish
            if not watcher.done():
                watcher.cancel()
                try:
                    await watcher
                except asyncio.CancelledError:
                    pass

    async def clear_kv_blocks(self, request=None):
        """Reset the engine's prefix cache and yield a single status dict."""
        try:
            await self.engine_client.reset_prefix_cache()
            reply = {"status": "success", "message": "KV cache cleared"}
            yield reply
        except Exception as exc:
            yield {"status": "error", "message": str(exc)}

    def cleanup(self):
        """No-op hook; subclasses override it when they have resources to release."""

    async def generate_tokens(
        self, prompt, sampling_params, request_id, data_parallel_rank=None
    ):
        """
        Stream engine output for one request as incremental token dicts.

        Each yielded dict has 'token_ids' holding only the tokens not yet
        emitted, plus 'finish_reason' / 'stop_reason' when the engine output
        sets them. A result without outputs yields a single error dict and
        ends the stream.

        Raises:
            GeneratorExit: if the task is cancelled while consuming the stream
        """
        try:
            stream = self.engine_client.generate(
                prompt,
                sampling_params,
                request_id,
                data_parallel_rank=data_parallel_rank,
            )

            emitted = 0
            try:
                async for step in stream:
                    # step is vllm's RequestOutput

                    if not step.outputs:
                        yield {"finish_reason": "error", "token_ids": []}
                        return

                    completion = step.outputs[0]
                    total = len(completion.token_ids)
                    chunk = {"token_ids": completion.token_ids[emitted:]}
                    finish = completion.finish_reason
                    if finish:
                        chunk["finish_reason"] = finish
                    stop = completion.stop_reason
                    if stop:
                        chunk["stop_reason"] = stop
                    yield chunk
                    # Only advance the cursor once the chunk has been handed over
                    emitted = total
            except asyncio.CancelledError:
                # Raise GeneratorExit when the engine exits so that the
                # frontend can migrate the request
                raise GeneratorExit(
                    "Decode engine was shut down during token generation"
                ) from None

        except EngineDeadError as exc:
            logger.error(f"vLLM EngineDeadError: {exc}")
            logger.warning("Initiating Dynamo Runtime shutdown.")
            self.runtime.shutdown()
            os._exit(1)
Alec's avatar
Alec committed
154
155
156
157


class DecodeWorkerHandler(BaseWorkerHandler):
    """
    Request handler that streams generated tokens for a request.
    """

    def __init__(self, runtime, component, engine, default_sampling_params):
        super().__init__(runtime, component, engine, default_sampling_params)

    async def generate(self, request, context):
        """
        Stream token dicts for one request.

        When the request carries 'disaggregated_params', its
        'kv_transfer_params' entry is forwarded to the engine through
        sampling_params.extra_args.
        """
        # The context ID doubles as the request ID for tracking and correlation
        request_id = context.id()
        logger.debug(f"Decode Request ID: {request_id}")

        prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
        sampling_params = build_sampling_params(request, self.default_sampling_params)

        # disaggregated_params is set by the prefill router in the Rust frontend
        prefill_result = request.get("disaggregated_params")
        if prefill_result:
            # Prefill already ran - hand its transfer params to the engine
            kv_params = prefill_result.get("kv_transfer_params")
            if sampling_params.extra_args is None:
                sampling_params.extra_args = {}
            sampling_params.extra_args["kv_transfer_params"] = kv_params
            logger.debug(
                f"Using disaggregated params from prefill for request {request_id}"
            )

        rank = request.get("dp_rank")

        async with self._abort_monitor(context, request_id):
            try:
                token_stream = self.generate_tokens(
                    prompt, sampling_params, request_id, data_parallel_rank=rank
                )
                async for chunk in token_stream:
                    yield chunk
            except EngineDeadError as exc:
                logger.error(f"vLLM EngineDeadError: {exc}")
                logger.warning("Initiating Dynamo Runtime shutdown.")
                self.runtime.shutdown()
                os._exit(1)
Alec's avatar
Alec committed
202
203
204


class PrefillWorkerHandler(BaseWorkerHandler):
    """
    Request handler that runs the prefill step for a remote decode.
    """

    def __init__(self, runtime, component, engine, default_sampling_params):
        super().__init__(runtime, component, engine, default_sampling_params)

    async def generate(self, request, context):
        """
        Run prefill for one request and yield the engine's results.

        Each yielded dict has 'token_ids' (the tokens of the first output, or
        an empty list when the result has no outputs) and
        'disaggregated_params' (the result's kv_transfer_params wrapped in a
        dict, or None when there are none).

        Raises:
            GeneratorExit: if the task is cancelled while consuming the stream
        """
        # Use context ID for request tracking and correlation with decode phase
        request_id = context.id()
        logger.debug(f"Prefill Request ID: {request_id}")

        prompt = TokensPrompt(prompt_token_ids=request["token_ids"])

        # Build sampling params from request using shared utility
        sampling_params = build_sampling_params(request, self.default_sampling_params)

        # Configure for prefill-only mode with remote decode
        if sampling_params.extra_args is None:
            sampling_params.extra_args = {}
        sampling_params.extra_args["kv_transfer_params"] = {
            "do_remote_decode": True,
        }
        # Override for prefill: only generate 1 token
        sampling_params.max_tokens = 1
        sampling_params.min_tokens = 1

        dp_rank = request.get("dp_rank", None)

        async with self._abort_monitor(context, request_id, is_prefill=True):
            # The whole stream, not just its creation, sits inside the
            # EngineDeadError guard: engine failures are expected to surface
            # while the stream is consumed. This mirrors generate_tokens().
            try:
                gen = self.engine_client.generate(
                    prompt, sampling_params, request_id, data_parallel_rank=dp_rank
                )

                try:
                    async for res in gen:
                        logger.debug(f"kv transfer params: {res.kv_transfer_params}")

                        output_token_ids = (
                            res.outputs[0].token_ids if res.outputs else []
                        )

                        output: Dict[str, Any] = {
                            "token_ids": list(output_token_ids),
                            "disaggregated_params": (
                                {"kv_transfer_params": res.kv_transfer_params}
                                if res.kv_transfer_params
                                else None
                            ),
                        }

                        yield output
                except asyncio.CancelledError:
                    # raise the error because we cannot migrate prefill requests
                    raise GeneratorExit(
                        "Prefill engine was shut down during token generation"
                    ) from None

            except EngineDeadError as e:
                logger.error(f"vLLM EngineDeadError: {e}")
                logger.warning("Initiating Dynamo Runtime shutdown.")
                self.runtime.shutdown()
                os._exit(1)