vllm_prepost_worker.py 2.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Lightweight token-based worker for vLLM frontend pre/post integration tests."""

from __future__ import annotations

import asyncio
import json
import os
from pathlib import Path
from typing import Any

import uvloop
from transformers import AutoTokenizer

from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
from tests.frontend.test_prepost import OUTPUTS_INTERVAL_20
from tests.frontend.test_vllm_prepost_integration import CAPTURE_PATH_ENV
from tests.utils.constants import QWEN


class VllmPrepostTestHandler:
    """Record the tokenized request it is given and reply with canned tokens.

    Each call to :meth:`generate` optionally dumps a JSON snapshot of the
    incoming request and then yields one chunk per entry of
    ``OUTPUTS_INTERVAL_20``.
    """

    def __init__(self, model_name: str = QWEN):
        # The tokenizer is used solely to decode the captured prompt tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _write_capture(self, request: dict[str, Any]) -> None:
        """Write a JSON snapshot of *request* to the configured capture file.

        The destination is read from the environment variable named by
        ``CAPTURE_PATH_ENV``; when it is unset or empty nothing is written.
        """
        destination = os.environ.get(CAPTURE_PATH_ENV)
        if destination:
            prompt_tokens = request.get("token_ids", [])

            # Build the record in a fixed key order so the JSON layout is stable.
            record = {
                "model": request.get("model"),
                "token_ids": prompt_tokens,
            }
            passthrough_keys = (
                "stop_conditions",
                "sampling_options",
                "output_options",
                "eos_token_ids",
            )
            record.update((key, request.get(key)) for key in passthrough_keys)
            record["decoded_prompt"] = self.tokenizer.decode(
                prompt_tokens, skip_special_tokens=False
            )

            target = Path(destination)
            target.parent.mkdir(parents=True, exist_ok=True)

            # Write to a sibling ".tmp" file first, then move it over the
            # target so the target is only ever replaced by a fully written file.
            staging = target.with_suffix(target.suffix + ".tmp")
            payload = json.dumps(record)
            staging.write_text(payload, encoding="utf-8")
            staging.replace(target)

    async def generate(self, request: dict[str, Any], context):
        """Capture *request*, then yield the fixed sequence of response chunks.

        ``context`` is accepted but not used. Every chunk carries
        ``token_ids``; ``finish_reason`` and ``stop_reason`` are included
        only when the corresponding output attribute is not ``None``.
        """
        self._write_capture(request)

        for step in OUTPUTS_INTERVAL_20:
            message = {"token_ids": list(step.token_ids)}
            for field in ("finish_reason", "stop_reason"):
                value = getattr(step, field)
                if value is not None:
                    message[field] = value
            yield message


async def main():
    """Register the test model on its endpoint and serve canned responses.

    Registers ``QWEN`` as a token-input chat model on the
    ``test.vllm-prepost.generate`` endpoint, then serves that endpoint with
    ``VllmPrepostTestHandler.generate``.
    """
    loop = asyncio.get_running_loop()
    # NOTE(review): "etcd" / "tcp" are passed positionally; their exact
    # meaning is defined by DistributedRuntime — confirm against its API.
    distributed_runtime = DistributedRuntime(loop, "etcd", "tcp", enable_nats=False)
    generate_endpoint = distributed_runtime.endpoint("test.vllm-prepost.generate")

    await register_model(
        ModelInput.Tokens, ModelType.Chat, generate_endpoint, QWEN, model_name=QWEN
    )

    # The handler is built after registration, then served on the endpoint.
    test_handler = VllmPrepostTestHandler(QWEN)
    await generate_endpoint.serve_endpoint(test_handler.generate)


# Script entry point: when executed directly, run the main() coroutine via uvloop.
if __name__ == "__main__":
    uvloop.run(main())