# tests/frontend/test_prompt_embeds.py
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
End-to-end tests for prompt embeddings support in Dynamo.

These tests validate behavior that cannot be covered by Rust unit tests:
- Streaming responses with embeddings
- Python-side tensor decoding errors
- Usage statistics from worker (the v2.0.4 bug fix)
- Large payload handling through the local request path
- Concurrent request handling

Validation tests (base64, size limits, empty prompt) are covered by Rust unit tests
in lib/llm/src/protocols/openai/completions.rs

Run with: pytest tests/frontend/test_prompt_embeds.py -v
"""

from __future__ import annotations

22
23
24
25
import base64
import concurrent.futures
import io
import logging
26
27
28
import os
import shutil
from typing import Generator
29
30
31
32
33

import pytest
import torch
from openai import OpenAI

34
35
36
37
from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
from tests.utils.payloads import check_models_api
from tests.utils.port_utils import ServicePorts

38
39
40
41
42
# Module-level logger named after this test module.
logger = logging.getLogger(__name__)

# Test model - small and fast for CI
TEST_MODEL = "Qwen/Qwen3-0.6B"

# Markers applied to every test collected from this module.
pytestmark = [
    pytest.mark.integration,
    pytest.mark.vllm,
    pytest.mark.nightly,
    pytest.mark.gpu_1,
    pytest.mark.model(TEST_MODEL),
]


class VllmPromptEmbedsWorkerProcess(ManagedProcess):
    """Vllm Worker process configured for prompt embeddings testing.

    Launches ``python3 -m dynamo.vllm`` for ``TEST_MODEL`` with
    ``--enable-prompt-embeds``, the ``file`` discovery backend and the ``tcp``
    request plane. No NATS or etcd required: the file backend automatically
    defaults the event plane to ZMQ.

    Args:
        request: pytest ``request`` object; its node name is used to build the
            log directory name.
        frontend_port: Port of the frontend whose ``/v1/models`` endpoint is
            polled as part of the health check.
        system_port: Port exported to the worker as ``DYN_SYSTEM_PORT`` and
            polled at ``/health``.
        worker_id: Label used in log messages and in the log directory name.
    """

    def __init__(
        self,
        request,
        *,
        frontend_port: int,
        system_port: int,
        worker_id: str = "vllm-prompt-embeds-worker",
    ):
        self.worker_id = worker_id
        self.frontend_port = int(frontend_port)
        self.system_port = int(system_port)

        command = [
            "python3",
            "-m",
            "dynamo.vllm",
            "--model",
            TEST_MODEL,
            "--max-model-len",
            "4096",
            "--discovery-backend",
            "file",
            "--request-plane",
            "tcp",
            "--enable-prompt-embeds",
            "--kv-events-config",
            '{"enable_kv_cache_events": false}',
        ]

        # Start from the caller's environment and layer the worker settings on top.
        env = os.environ.copy()
        env["DYN_LOG"] = "debug"
        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
        env["DYN_SYSTEM_PORT"] = str(self.system_port)

        log_dir = f"{request.node.name}_{worker_id}"

        # Drop logs left over from a previous run of the same test; a missing
        # directory is the normal case and is not an error.
        try:
            shutil.rmtree(log_dir)
        except FileNotFoundError:
            pass

        super().__init__(
            command=command,
            env=env,
            health_check_urls=[
                (f"http://localhost:{self.frontend_port}/v1/models", check_models_api),
                (f"http://localhost:{self.system_port}/health", self.is_ready),
            ],
            timeout=500,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=["VLLM::EngineCore"],
            straggler_commands=["-m dynamo.vllm"],
            log_dir=log_dir,
        )

    def is_ready(self, response) -> bool:
        """Return True when the health response body reports ``status == "ready"``.

        Args:
            response: Response object exposing a ``json()`` method that raises
                ``ValueError`` on an undecodable body.

        Returns:
            ``True`` only for a JSON object whose ``"status"`` is ``"ready"``;
            ``False`` for invalid JSON, a non-object body, or any other status.
        """
        try:
            payload = response.json()
        except ValueError:
            logger.warning("%s health response is not valid JSON", self.worker_id)
            return False

        # A non-object body (list, string, number, null) carries no status
        # field; treat it as "not ready" rather than failing on ``.get``.
        status = payload.get("status") if isinstance(payload, dict) else None

        is_ready = status == "ready"
        if is_ready:
            logger.info("%s status is ready", self.worker_id)
        else:
            logger.warning("%s status is not ready: %s", self.worker_id, status)
        return is_ready


@pytest.fixture(scope="function")
def start_services(
    request,
    file_storage_backend,
    dynamo_dynamic_ports: ServicePorts,
    predownload_models,
) -> Generator[ServicePorts, None, None]:
    """Start frontend and vllm worker processes for prompt embeds testing.

    Uses file-based KV store and TCP request plane. No NATS or etcd needed:
    the file backend automatically defaults the event plane to ZMQ, avoiding
    all external service dependencies and keeping tests simpler and faster.

    The `file_storage_backend` fixture sets up a temporary directory and
    configures DYN_FILE_KV environment variable.

    Yields:
        The allocated ``ServicePorts`` while both processes are running. The
        worker context is exited first, then the frontend, once the test ends.
    """
    _ = file_storage_backend  # Ensures temp dir is set up and DYN_FILE_KV is configured
    _ = predownload_models  # Ensures model is downloaded before starting services
    frontend_port = dynamo_dynamic_ports.frontend_port
    system_port = dynamo_dynamic_ports.system_ports[0]

    # The frontend is entered first so it is already up when the worker's
    # health check polls its /v1/models endpoint.
    with DynamoFrontendProcess(
        request,
        frontend_port=frontend_port,
        terminate_all_matching_process_names=False,
        extra_args=["--discovery-backend", "file", "--request-plane", "tcp"],
    ):
        logger.info("Frontend started for prompt embeds tests")
        with VllmPromptEmbedsWorkerProcess(
            request,
            frontend_port=frontend_port,
            system_port=system_port,
        ):
            logger.info("Vllm Worker with prompt embeds started for tests")
            yield dynamo_dynamic_ports

166
167

@pytest.fixture
def dynamo_client(start_services: ServicePorts) -> OpenAI:
    """Create OpenAI client pointing to Dynamo frontend on the allocated port.

    Depending on ``start_services`` means the frontend and worker are started
    before the client is handed to a test.
    """
    return OpenAI(
        api_key="EMPTY",
        base_url=f"http://localhost:{start_services.frontend_port}/v1",
    )


def create_embeddings_base64(shape: tuple[int, ...]) -> str:
    """Serialize a random float32 tensor of ``shape`` with ``torch.save`` and base64-encode it.

    Args:
        shape: Dimensions of the tensor to generate.

    Returns:
        The base64 text of the ``torch.save`` payload.
    """
    tensor = torch.randn(*shape, dtype=torch.float32)
    with io.BytesIO() as stream:
        torch.save(tensor, stream)
        # getvalue() returns the whole buffer regardless of the stream position.
        serialized = stream.getvalue()
    encoded = base64.b64encode(serialized)
    return encoded.decode("utf-8")


class TestPromptEmbedsE2E:
    """
    End-to-end tests for prompt embeddings.

    These tests require a running Dynamo instance with vLLM backend.
    They validate behavior that Rust unit tests cannot cover.
    """

    def test_streaming_with_embeddings(self, dynamo_client):
        """
        Test streaming responses work correctly with embeddings.

        This is E2E only - Rust tests can't verify streaming behavior.
        """
        embeddings_base64 = create_embeddings_base64((10, 1024))

        stream = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=10,
            stream=True,
            extra_body={"prompt_embeds": embeddings_base64},
        )

        chunks = list(stream)

        assert len(chunks) > 0, "Should receive at least one chunk"
        # A trailing chunk may carry no choices, so check the last chunk that
        # does; otherwise the finish_reason check would be silently skipped.
        chunks_with_choices = [chunk for chunk in chunks if chunk.choices]
        assert chunks_with_choices, "Should receive at least one chunk with choices"
        assert chunks_with_choices[-1].choices[0].finish_reason is not None

    def test_invalid_tensor_data_rejected(self, dynamo_client):
        """
        Test that invalid tensor data is properly rejected by Python decoder.

        This tests the Python-side torch.load() error handling, which
        Rust validation cannot cover (Rust only checks base64 and size).
        """
        # Create data that passes Rust validation (valid base64, >100 bytes)
        # but fails Python torch.load()
        invalid_data = b"this is not a valid pytorch tensor format!" * 10
        invalid_base64 = base64.b64encode(invalid_data).decode("utf-8")

        with pytest.raises(Exception) as exc_info:
            dynamo_client.completions.create(
                model=TEST_MODEL,
                prompt="",
                max_tokens=5,
                extra_body={"prompt_embeds": invalid_base64},
            )

        error_msg = str(exc_info.value).lower()
        assert any(
            keyword in error_msg
            for keyword in ["pytorch", "tensor", "invalid", "decode", "error"]
        ), f"Expected tensor decode error, got: {error_msg}"

    def test_usage_prompt_tokens_not_zero(self, dynamo_client):
        """
        CRITICAL REGRESSION TEST: Ensure prompt_tokens is correctly reported.

        This validates the v2.0.4 fix where prompt_tokens was incorrectly
        reported as 0 when using embeddings. The worker extracts sequence
        length from tensor shape and includes it in completion_usage.

        Rust tests cannot verify this - it requires E2E validation.
        """
        sequence_length = 20
        embeddings_base64 = create_embeddings_base64((sequence_length, 1024))

        response = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=3,
            extra_body={"prompt_embeds": embeddings_base64},
        )

        assert response.usage is not None, "Should have usage statistics"
        assert (
            response.usage.prompt_tokens != 0
        ), "BUG REGRESSION: prompt_tokens is 0! This was the bug in v2.0.3."
        assert (
            response.usage.prompt_tokens == sequence_length
        ), f"Expected prompt_tokens={sequence_length}, got {response.usage.prompt_tokens}"
        assert response.usage.total_tokens == (
            response.usage.prompt_tokens + response.usage.completion_tokens
        ), "total_tokens should equal prompt_tokens + completion_tokens"

    def test_large_embeddings_through_local_request_path(self, dynamo_client):
        """
        Test large embeddings are handled correctly through the local request path.

        This validates the E2E frontend-to-worker path handles large embedding
        payloads. Rust unit tests can't test this E2E path.
        """
        max_payload_bytes = 10 * 1024 * 1024

        # 1700 x 1024 float32 values are ~6.6 MiB of raw data: well under the
        # 10MB limit, but large enough to stress the path.
        large_shape = (1700, 1024)
        large_embeds = torch.randn(large_shape, dtype=torch.float32)

        buffer = io.BytesIO()
        torch.save(large_embeds, buffer)
        large_bytes = buffer.getvalue()
        large_base64 = base64.b64encode(large_bytes).decode("utf-8")

        # Check the precondition before sending, so an oversized fixture is
        # reported as a test-data problem rather than as a request failure.
        assert len(large_bytes) < max_payload_bytes, "Test data should be under 10MB"

        logger.info(
            "Testing large embeddings: %.2fMB decoded",
            len(large_bytes) / 1024 / 1024,
        )

        response = dynamo_client.completions.create(
            model=TEST_MODEL,
            prompt="",
            max_tokens=5,
            extra_body={"prompt_embeds": large_base64},
        )

        assert response.choices, "Large embeddings should produce valid response"

    def test_concurrent_embeddings_requests(self, dynamo_client):
        """
        Test concurrent requests with embeddings are handled correctly.

        This validates the worker can handle multiple embedding requests
        simultaneously without race conditions or resource conflicts.
        """
        embeddings_base64 = create_embeddings_base64((10, 1024))

        def send_request():
            return dynamo_client.completions.create(
                model=TEST_MODEL,
                prompt="",
                max_tokens=5,
                extra_body={"prompt_embeds": embeddings_base64},
            )

        # Send 5 concurrent requests
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(send_request) for _ in range(5)]
            # result() re-raises any exception from a worker thread, so a
            # failed request fails the test here.
            results = [f.result() for f in concurrent.futures.as_completed(futures)]

        assert len(results) == 5, "All concurrent requests should complete"
        for response in results:
            assert response.choices, "Each response should have choices"
            # Truthiness also rejects a None text instead of raising TypeError.
            assert response.choices[0].text, "Each response should have text"