# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import asyncio
import json
import logging
import os
import signal
import sys
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union

import uvloop
from transformers import AutoTokenizer
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
from vllm.outputs import RequestOutput
20
21
from vllm.tokenizers import TokenizerLike as AnyTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
22

23
from dynamo.llm import ModelInput, ModelType, register_model
24
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
25
26
27
28
29
from dynamo.runtime.logging import configure_dynamo_logging

# To import example local module
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from utils.args import Config, base_parse_args, parse_endpoint
30
from utils.chat_message_utils import extract_user_text
31
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
32
33
34
35
36
37
from utils.protocol import (
    MultiModalInput,
    MultiModalRequest,
    MyRequestOutput,
    vLLMMultimodalRequest,
)
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

# Dynamo's logging setup runs once at import time, before the module-level
# logger below is used anywhere in this file.
configure_dynamo_logging()
logger = logging.getLogger(name=__name__)


class RequestType(Enum):
    """Kinds of OpenAI-style requests the processor can forward.

    Only ``CHAT`` is streamed back by ``Processor._generate_responses``;
    any other member is rejected there with ``NotImplementedError``.
    """

    CHAT = "chat"  # chat completion request
    COMPLETION = "completion"  # plain text completion request


class Processor(ProcessMixIn):
    """
    vLLM pre and post processing.

    Accepts ``MultiModalRequest`` payloads from the frontend, renders the user
    text into ``--prompt-template``, forwards the prompt together with the
    media URL to the downstream encode worker through ``encode_worker_client``,
    and streams the worker's outputs back as parsed OpenAI chat completion
    chunks (JSON dicts).
    """

    @classmethod
    def parse_args(cls) -> Tuple[argparse.Namespace, Config]:
        """Parse the processor CLI flags on top of the shared base arguments.

        Default endpoints are derived from the ``DYN_NAMESPACE`` environment
        variable, falling back to ``"dynamo"``.

        Returns:
            ``(args, config)`` as produced by ``base_parse_args``.
        """
        dyn_namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
        default_endpoint = f"dyn://{dyn_namespace}.processor.generate"
        default_downstream_endpoint = f"dyn://{dyn_namespace}.encoder.generate"

        parser = FlexibleArgumentParser(
            description="vLLM based processor for Dynamo LLM."
        )
        parser.add_argument(
            "--prompt-template",
            type=str,
            required=True,
            help=(
                "Different multi-modal models expect the prompt to contain different special media prompts. "
                "The processor will use this argument to construct the final prompt. "
                "User prompt will replace '<prompt>' in the provided template. "
                "For example, if the user prompt is 'please describe the image' and the prompt template is "
                "'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
                "'USER: <image> please describe the image ASSISTANT:'."
            ),
        )
        parser.add_argument(
            "--endpoint",
            type=str,
            default=default_endpoint,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: '{default_endpoint}'",
        )
        parser.add_argument(
            "--downstream-endpoint",
            type=str,
            default=default_downstream_endpoint,
            help=f"The endpoint string of the downstream encoder in 'dyn://namespace.component.endpoint' format. Default: '{default_downstream_endpoint}'",
        )

        args, config = base_parse_args(parser)

        return args, config

    def __init__(
        self,
        args: argparse.Namespace,
        engine_args: AsyncEngineArgs,
        encode_worker_client: Client,
    ):
        """Build the model config, tokenizer and request processors.

        Args:
            args: Parsed CLI arguments; ``args.prompt_template`` is read.
            engine_args: vLLM engine arguments used to create the model
                config and to locate the tokenizer.
            encode_worker_client: Client for the downstream encode worker
                that receives the preprocessed requests.
        """
        self.encode_worker_client = encode_worker_client
        self.prompt_template = args.prompt_template
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.tokenizer = self._create_tokenizer(self.engine_args)
        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
        self.completions_processor = CompletionsProcessor(
            self.tokenizer, self.model_config
        )

    def cleanup(self):
        """Release processor resources (currently there is nothing to release)."""
        pass

    def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
        """Load a Hugging Face tokenizer for ``engine_args.model``.

        Returns:
            The tokenizer from ``AutoTokenizer.from_pretrained``, configured
            with left padding and left truncation.
        """
        model_path = engine_args.model

        # NOTE(review): trust_remote_code is always enabled here, regardless of
        # the engine's own trust_remote_code setting — confirm this is intended.
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left",
            use_fast=True,  # VLLM might use the fast tokenizer for efficiency
        )
        return base_tokenizer

    # Main method to parse the request and send the request to the vllm worker.
    async def _generate(
        self,
        raw_request: Union[CompletionRequest, ChatCompletionRequest],
        multimodal_input: MultiModalInput,
        request_type: RequestType,
    ):
        """Forward one preprocessed request to the encode worker and stream back.

        Args:
            raw_request: The OpenAI-style request to preprocess.
            multimodal_input: Media URL(s) attached to the request.
            request_type: Selects how worker outputs are post-processed.

        Yields:
            Items produced by ``self._stream_response`` for this request.
        """
        # uuid4().hex is already a str.
        request_id = uuid.uuid4().hex
        logger.debug(f"Got raw request: {raw_request}")
        (
            request,
            conversation,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)

        worker_request = vLLMMultimodalRequest(
            engine_prompt=engine_prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            multimodal_input=multimodal_input,
        )

        # model_dump_json() serializes the request to JSON string
        # This API could accept Pydantic class, but SamplingParams
        # in vLLMMultimodalRequest is not a Pydantic class and will
        # cause TypeError: unsupported type SamplingParams
        response_generator = await self.encode_worker_client.round_robin(
            worker_request.model_dump_json()
        )

        output = self._generate_responses(response_generator, request_type)

        # Stream the processed responses
        async for response in await self._stream_response(
            request, output, request_id, conversation
        ):
            yield response

    # This method is used to process the responses from the engine generator.
    async def _generate_responses(
        self,
        response_generator: AsyncIterator[RequestOutput],
        request_type: RequestType,
    ):
        """Convert serialized worker responses into vLLM ``RequestOutput`` objects.

        NOTE(review): each item is accessed through ``resp.data()`` (JSON for
        ``MyRequestOutput``), so the items are not literally ``RequestOutput``
        despite the annotation.

        Raises:
            NotImplementedError: for any request type other than ``CHAT``.
                This is raised when iteration starts, before any worker
                response is consumed.
        """
        if request_type != RequestType.CHAT:
            raise NotImplementedError(f"Request type {request_type} not implemented")

        async for resp in response_generator:
            # Deserialize the response from the engine
            # Creates correct vLLM objects for each field
            output = MyRequestOutput.model_validate_json(resp.data())

            # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
            yield RequestOutput(
                request_id=output.request_id,
                prompt=output.prompt,
                prompt_token_ids=output.prompt_token_ids,
                prompt_logprobs=output.prompt_logprobs,
                outputs=output.outputs,
                finished=output.finished,
                metrics=output.metrics,
            )

    @staticmethod
    def _collect_multimodal_input(messages) -> MultiModalInput:
        """Extract the media URL carried by the request messages.

        Only one media kind (image, video or audio) may appear per request.
        If the same kind appears several times, the last URL wins, because
        ``MultiModalInput`` stores a single URL per kind.

        Raises:
            ValueError: if different media kinds are mixed, or if no media
                URL is present at all.
        """
        multimodal_input = MultiModalInput()
        for message in messages:
            for item in message.content:
                if item.type == "image_url":
                    if multimodal_input.video_url is not None:
                        raise ValueError("Cannot provide both image and video URLs")
                    if multimodal_input.audio_url is not None:
                        raise ValueError("Cannot mix image, video and audio URLs")
                    multimodal_input.image_url = item.image_url.url
                elif item.type == "video_url":
                    if multimodal_input.image_url is not None:
                        raise ValueError("Cannot provide both image and video URLs")
                    if multimodal_input.audio_url is not None:
                        raise ValueError("Cannot mix image, video and audio URLs")
                    multimodal_input.video_url = item.video_url.url
                elif item.type == "audio_url":
                    if (
                        multimodal_input.image_url is not None
                        or multimodal_input.video_url is not None
                    ):
                        raise ValueError("Cannot mix image, video and audio URLs")
                    multimodal_input.audio_url = item.audio_url.url

        if (
            multimodal_input.image_url is None
            and multimodal_input.video_url is None
            and multimodal_input.audio_url is None
        ):
            raise ValueError("Either image URL or video URL or audio URL is required")
        return multimodal_input

    # The generate endpoint will be used by the frontend to handle incoming requests.
    async def generate(self, raw_request: MultiModalRequest):
        """Handle one frontend request and stream chat completion chunks.

        Args:
            raw_request: A ``MultiModalRequest``, or a value that
                ``MultiModalRequest.model_validate`` accepts.

        Yields:
            Parsed JSON chunks (dicts) of the OpenAI chat completion stream,
            stopping at the ``[DONE]`` sentinel.

        Raises:
            ValueError: if the prompt template lacks ``<prompt>``, or if the
                media URLs are missing or of mixed kinds.
        """
        logger.debug(f"Got raw request: {raw_request}")
        if not isinstance(raw_request, MultiModalRequest):
            # If the request is not MultiModalRequest, convert it to MultiModalRequest
            raw_request = MultiModalRequest.model_validate(raw_request)

        # Ensure the configured template includes the placeholder
        template = self.prompt_template
        if "<prompt>" not in template:
            raise ValueError("prompt_template must contain '<prompt>' placeholder")

        user_text = extract_user_text(raw_request.messages)
        msg = {
            "role": "user",
            "content": template.replace("<prompt>", user_text),
        }

        # Set stream=True - the http frontend will handle aggregation of
        # streamed chunks into a single http response, or stream them
        # back as SSE responses based on the stream flag in the request.
        chat_request = ChatCompletionRequest(
            model=raw_request.model,
            messages=[msg],
            stream=True,
            max_tokens=raw_request.max_tokens,
            temperature=raw_request.temperature,
            request_id=str(uuid.uuid4()),
        )

        multimodal_input = self._collect_multimodal_input(raw_request.messages)

        async for response in self._generate(
            chat_request, multimodal_input, RequestType.CHAT
        ):
            logger.debug(
                f"Generated response type {type(response)}, content: {response}"
            )
            # Responses are SSE lines of the form "data: <payload>". Remove the
            # "data:" field name as a *prefix*: str.lstrip("data: ") would strip
            # any leading run of the characters {d, a, t, :, space} instead.
            payload = response.removeprefix("data:").strip()
            if payload.startswith("[DONE]"):
                break
            # reconstructing back the OpenAI chat response as dynamo egress expects it
            yield json.loads(payload)


async def graceful_shutdown(runtime):
    """
    By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
    However, in-flight requests will still be processed until they are finished.
    After all in-flight requests are finished, the `serve_endpoint` functions will return
    and the engine will be shutdown by Python's garbage collector.

    Args:
        runtime: The ``DistributedRuntime`` to shut down.
    """
    # Use the module logger, like the rest of this file, rather than the root logger.
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    # shutdown() has returned at this point; in-flight requests may still be
    # draining, as described in the docstring above.
    logger.info("DistributedRuntime shutdown complete")


@dynamo_worker()
async def worker(runtime: DistributedRuntime):
    """Install graceful-shutdown signal handlers, then initialise and serve.

    The ``__main__`` block calls ``worker()`` with no arguments, so ``runtime``
    is supplied by the ``dynamo_worker`` decorator.
    """
    # Runtime setup
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # The event loop keeps only weak references to tasks. Hold strong
    # references here so a shutdown task cannot be garbage-collected
    # before it finishes.
    shutdown_tasks = set()

    def signal_handler():
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logger.info("Signal handlers set up for graceful shutdown")

    # worker setup
    args, config = Processor.parse_args()
    await init(runtime, args, config)


async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Config):
    """
    Instantiate and serve.

    Steps:
    - Connect a client to the downstream encode worker.
    - Build the ``Processor``.
    - Wait for encoder instances.
    - Register the model on this component's endpoint.
    - Serve ``Processor.generate`` until the endpoint stops.

    Args:
        runtime: The Dynamo distributed runtime.
        args: Parsed CLI arguments (``downstream_endpoint``, ``endpoint``, ...).
        config: Parsed base configuration (namespace, component, model, ...).

    Raises:
        Exception: any error from serving the endpoint is logged and re-raised.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    generate_endpoint = component.endpoint(config.endpoint)

    # Client for the downstream encode worker that receives preprocessed requests.
    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
        args.downstream_endpoint
    )
    encode_worker_client = (
        await runtime.namespace(parsed_namespace)
        .component(parsed_component_name)
        .endpoint(parsed_endpoint_name)
        .client()
    )

    handler = Processor(args, config.engine_args, encode_worker_client)

    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register the endpoint as entrypoint to a model
    await register_model(
        ModelInput.Text,  # Custom processor is used and this type bypasses SDK processor
        ModelType.Chat,
        generate_endpoint,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info(f"Starting to serve the {args.endpoint} endpoint...")

    try:
        # Only one endpoint is served, so await it directly; no gather needed.
        await generate_endpoint.serve_endpoint(
            handler.generate, metrics_labels=[("model", config.model)]
        )
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()


if __name__ == "__main__":
    # uvloop.run() runs the coroutine on a fresh uvloop event loop. It replaces
    # the uvloop.install() + asyncio.run() pattern; install() is deprecated
    # on Python 3.12+.
    uvloop.run(worker())