# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import io
5
6
7
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock

8
import pybase64
9
import pytest
10
import torch
11

12
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
13
from vllm.inputs.data import is_embeds_prompt
14
15
16
17
18


@dataclass
class MockModelConfig:
    max_model_len: int = 100
19
    encoder_config: dict | None = None
20
    enable_prompt_embeds: bool = True
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


class MockTokenizerResult:
    """Minimal fake of a tokenizer call's return value.

    The only attribute it carries is ``input_ids``.
    """

    def __init__(self, input_ids):
        # Stored as-is (no copy), so callers get back the exact object passed in.
        self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
    """Provide a ``MockModelConfig`` populated with its default values."""
    config = MockModelConfig()
    return config


@pytest.fixture
def mock_tokenizer():
    """Provide a fresh ``MagicMock`` to act as the renderer's tokenizer.

    The renderer fixture uses it as the key into ``async_tokenizer_pool``.
    """
    return MagicMock()


@pytest.fixture
def mock_async_tokenizer():
    """Provide a fresh ``AsyncMock`` standing in for the async tokenizer.

    Tests configure its ``return_value`` / ``decode`` and register it in
    ``renderer.async_tokenizer_pool``.
    """
    return AsyncMock()


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
    """Build a ``CompletionRenderer`` over the mock config and tokenizer.

    The async tokenizer pool starts empty. Tests that tokenize text register
    ``mock_async_tokenizer`` in it under ``renderer.tokenizer``.
    """
    return CompletionRenderer(
        model_config=mock_model_config,
        tokenizer=mock_tokenizer,
        async_tokenizer_pool={},
    )


class TestRenderPrompt:
    """Basic functionality tests for ``CompletionRenderer.render_prompt``.

    Covers:
    - token-ID and text inputs, single and batched;
    - truncation via ``truncate_prompt_tokens`` (positive and ``-1``);
    - the maximum-context-length check;
    - the missing-tokenizer error;
    - detokenization of token inputs.
    """

    @pytest.mark.asyncio
    async def test_token_input(self, renderer):
        # A flat list of token IDs is rendered as one prompt, unchanged.
        tokens = [101, 7592, 2088]
        results = await renderer.render_prompt(
            prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    @pytest.mark.asyncio
    async def test_token_list_input(self, renderer):
        # A list of token-ID lists yields one rendered prompt per inner list.
        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        results = await renderer.render_prompt(
            prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    @pytest.mark.asyncio
    async def test_text_input(self, renderer, mock_async_tokenizer):
        # Text is tokenized by the async tokenizer registered in the pool
        # under ``renderer.tokenizer``.
        mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        mock_async_tokenizer.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_list_input(self, renderer, mock_async_tokenizer):
        # Each string in a batch triggers its own tokenizer call.
        mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        text_list_input = ["Hello world", "How are you?", "Good morning"]
        results = await renderer.render_prompt(
            prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
        )

        assert len(results) == 3
        for result in results:
            assert result["prompt_token_ids"] == [101, 7592, 2088]
        assert mock_async_tokenizer.call_count == 3

    @pytest.mark.asyncio
    async def test_no_truncation(self, renderer, mock_async_tokenizer):
        # Without ``truncate_prompt_tokens`` the tokenizer must not be asked
        # to truncate (the kwarg is either absent or explicitly False).
        mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert (
            "truncation" not in call_args.kwargs
            or call_args.kwargs["truncation"] is False
        )

    @pytest.mark.asyncio
    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
        # A positive ``truncate_prompt_tokens`` is forwarded to the tokenizer
        # as ``max_length`` with truncation enabled.
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088]
        )  # Truncated
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world",
            config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
        # Test that negative truncation uses model's max_model_len
        # (100 in MockModelConfig), not the RenderConfig max_length of 200.
        mock_async_tokenizer.return_value = MockTokenizerResult(
            [101, 7592, 2088]
        )  # Truncated to max_model_len
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world",
            config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 100  # model's max_model_len

    @pytest.mark.asyncio
    async def test_token_truncation_last_elements(self, renderer):
        # Test that token truncation keeps the last N elements
        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        results = await renderer.render_prompt(
            prompt_or_prompts=long_tokens,
            config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
    async def test_max_length_exceeded(self, renderer):
        # 150 tokens exceed max_length=100 and must be rejected, not truncated.
        long_tokens = list(range(150))  # Exceeds max_model_len=100

        with pytest.raises(ValueError, match="maximum context length"):
            await renderer.render_prompt(
                prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
            )

    @pytest.mark.asyncio
    async def test_no_tokenizer_for_text(self, mock_model_config):
        # Rendering text requires a tokenizer. With ``tokenizer=None`` the
        # renderer must raise instead of failing obscurely.
        renderer_no_tokenizer = CompletionRenderer(
            model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
        )

        with pytest.raises(ValueError, match="No tokenizer available"):
            await renderer_no_tokenizer.render_prompt(
                prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
            )

    @pytest.mark.asyncio
    async def test_token_input_with_needs_detokenization(
        self, renderer, mock_async_tokenizer
    ):
        # When needs_detokenization=True for token inputs, renderer should
        # use the async tokenizer to decode and include the original text
        # in the returned prompt object.
        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        tokens = [1, 2, 3, 4]
        results = await renderer.render_prompt(
            prompt_or_prompts=tokens,
            config=RenderConfig(needs_detokenization=True),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        assert results[0]["prompt"] == "decoded text"
        mock_async_tokenizer.decode.assert_awaited_once()


class TestRenderEmbedPrompt:
    """Tests for ``CompletionRenderer.render_prompt_and_embeds``.

    Covers:
    - single and batched serialized embeddings;
    - truncation of embeddings;
    - dtype preservation;
    - squeezing of a leading batch dimension;
    - mixing a text prompt with embeddings.
    """

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and base64-encode the result.

        This is the encoded form these tests pass as ``prompt_embeds``.
        """
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek(0) is needed.
        return pybase64.b64encode(buffer.getvalue())

    @pytest.mark.asyncio
    async def test_single_prompt_embed(self, renderer):
        # A single encoded tensor round-trips into one embeds prompt, and
        # cache_salt from the config is attached to it.
        test_tensor = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(cache_salt="test_salt"),
        )

        assert len(results) == 1
        assert is_embeds_prompt(results[0])
        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
        assert results[0]["cache_salt"] == "test_salt"

    @pytest.mark.asyncio
    async def test_multiple_prompt_embeds(self, renderer):
        # A list of encoded tensors (different sequence lengths) yields one
        # embeds prompt per tensor, in input order.
        test_tensors = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]
        embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes_list,
            config=RenderConfig(),
        )

        assert len(results) == 2
        for result, expected in zip(results, test_tensors):
            assert is_embeds_prompt(result)
            assert torch.allclose(result["prompt_embeds"], expected)

    @pytest.mark.asyncio
    async def test_prompt_embed_truncation(self, renderer):
        # Create tensor with more tokens than truncation limit
        test_tensor = torch.randn(20, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(truncate_prompt_tokens=10),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = test_tensor[-10:]
        assert torch.allclose(results[0]["prompt_embeds"], expected)

    @pytest.mark.asyncio
    async def test_prompt_embed_different_dtypes(self, renderer):
        # The decoded embeddings must keep the dtype they were saved with.
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            test_tensor = torch.randn(5, 256, dtype=dtype)
            embed_bytes = self._create_test_embed_bytes(test_tensor)

            results = await renderer.render_prompt_and_embeds(
                prompt_embeds=embed_bytes,
                config=RenderConfig(),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    @pytest.mark.asyncio
    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
        # Test tensor with batch dimension gets squeezed
        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    @pytest.mark.asyncio
    async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
        # Set up text tokenization
        mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        # Create embed
        test_tensor = torch.randn(5, 256, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_or_prompts="Hello world",
            prompt_embeds=embed_bytes,
            config=RenderConfig(),
        )

        assert len(results) == 2
        # First should be embed prompt
        assert is_embeds_prompt(results[0])
        # Second should be tokens prompt
        assert "prompt_token_ids" in results[1]
        assert results[1]["prompt_token_ids"] == [101, 102, 103]