test_renderer.py 11.6 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import io
5
6
7
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock

8
import pybase64
9
import pytest
10
import torch
11

12
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
13
from vllm.inputs.data import is_embeds_prompt
14
15
16
17
18


@dataclass
class MockModelConfig:
    max_model_len: int = 100
19
    encoder_config: dict | None = None
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45


class MockTokenizerResult:
    """Fake tokenizer output exposing only an ``input_ids`` attribute.

    The async-tokenizer mocks in this file return instances of this class;
    the tests expect its ``input_ids`` to surface as ``prompt_token_ids``.
    """

    def __init__(self, input_ids):
        # Kept exactly as given: no copy and no validation.
        self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
    """Provide a fresh ``MockModelConfig`` populated with its default values."""
    config = MockModelConfig()
    return config


@pytest.fixture
def mock_tokenizer():
    """Provide a bare ``MagicMock`` standing in for the synchronous tokenizer."""
    return MagicMock()


@pytest.fixture
def mock_async_tokenizer():
    """Provide an ``AsyncMock`` standing in for an async tokenizer.

    Tests configure its return value and register it in the renderer's
    async tokenizer pool themselves.
    """
    return AsyncMock()


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
    """Build a ``CompletionRenderer`` wired to the mock config and tokenizer.

    The async tokenizer pool starts empty. Tests that tokenize text register
    a mock async tokenizer under ``renderer.tokenizer`` themselves.
    """
    return CompletionRenderer(
        model_config=mock_model_config,
        tokenizer=mock_tokenizer,
        async_tokenizer_pool={},
    )


class TestRenderPrompt:
    """Basic functionality tests for ``CompletionRenderer.render_prompt``.

    Covers:
    - token and text inputs, single and batched;
    - truncation settings;
    - max-length validation;
    - the missing-tokenizer error;
    - detokenization of token inputs.
    """

    @staticmethod
    def _install_async_tokenizer(renderer, async_tokenizer, token_ids):
        """Route text tokenization in ``renderer`` through ``async_tokenizer``.

        Every call to the mock returns a ``MockTokenizerResult`` wrapping
        ``token_ids``. The mock is registered in the renderer's async tokenizer
        pool under ``renderer.tokenizer``.
        """
        async_tokenizer.return_value = MockTokenizerResult(token_ids)
        renderer.async_tokenizer_pool[renderer.tokenizer] = async_tokenizer

    @pytest.mark.asyncio
    async def test_token_input(self, renderer):
        tokens = [101, 7592, 2088]
        results = await renderer.render_prompt(
            prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    @pytest.mark.asyncio
    async def test_token_list_input(self, renderer):
        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        results = await renderer.render_prompt(
            prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
        )

        assert len(results) == 3
        # Each prompt comes back unchanged and in input order.
        assert [r["prompt_token_ids"] for r in results] == token_lists

    @pytest.mark.asyncio
    async def test_text_input(self, renderer, mock_async_tokenizer):
        self._install_async_tokenizer(
            renderer, mock_async_tokenizer, [101, 7592, 2088]
        )

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        mock_async_tokenizer.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_list_input(self, renderer, mock_async_tokenizer):
        self._install_async_tokenizer(
            renderer, mock_async_tokenizer, [101, 7592, 2088]
        )

        text_list_input = ["Hello world", "How are you?", "Good morning"]
        results = await renderer.render_prompt(
            prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
        )

        assert len(results) == 3
        for result in results:
            assert result["prompt_token_ids"] == [101, 7592, 2088]
        # One tokenizer call per input string.
        assert mock_async_tokenizer.call_count == 3

    @pytest.mark.asyncio
    async def test_no_truncation(self, renderer, mock_async_tokenizer):
        self._install_async_tokenizer(
            renderer, mock_async_tokenizer, [101, 7592, 2088]
        )

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        # Without truncate_prompt_tokens the tokenizer must not be asked to
        # truncate: the kwarg is either absent or explicitly False.
        assert (
            "truncation" not in call_args.kwargs
            or call_args.kwargs["truncation"] is False
        )

    @pytest.mark.asyncio
    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
        # The mocked output is arbitrary; this test checks the kwargs that
        # are forwarded to the tokenizer.
        self._install_async_tokenizer(
            renderer, mock_async_tokenizer, [101, 7592, 2088]
        )

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world",
            config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
        # Negative truncation should fall back to the model's max_model_len,
        # even when the config's max_length is larger.
        self._install_async_tokenizer(
            renderer, mock_async_tokenizer, [101, 7592, 2088]
        )

        results = await renderer.render_prompt(
            prompt_or_prompts="Hello world",
            config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        call_args = mock_async_tokenizer.call_args
        assert call_args.kwargs["truncation"] is True
        assert call_args.kwargs["max_length"] == 100  # model's max_model_len

    @pytest.mark.asyncio
    async def test_token_truncation_last_elements(self, renderer):
        # Token-input truncation keeps the last N elements.
        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        results = await renderer.render_prompt(
            prompt_or_prompts=long_tokens,
            config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
    async def test_max_length_exceeded(self, renderer):
        long_tokens = list(range(150))  # Exceeds max_model_len=100

        with pytest.raises(ValueError, match="maximum context length"):
            await renderer.render_prompt(
                prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
            )

    @pytest.mark.asyncio
    async def test_no_tokenizer_for_text(self, mock_model_config):
        renderer_no_tokenizer = CompletionRenderer(
            model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
        )

        with pytest.raises(ValueError, match="No tokenizer available"):
            await renderer_no_tokenizer.render_prompt(
                prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
            )

    @pytest.mark.asyncio
    async def test_token_input_with_needs_detokenization(
        self, renderer, mock_async_tokenizer
    ):
        # When needs_detokenization=True for token inputs, the renderer should
        # decode them with the async tokenizer and include the decoded text
        # in the returned prompt object.
        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        tokens = [1, 2, 3, 4]
        results = await renderer.render_prompt(
            prompt_or_prompts=tokens,
            config=RenderConfig(needs_detokenization=True),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        assert results[0]["prompt"] == "decoded text"
        mock_async_tokenizer.decode.assert_awaited_once()


class TestRenderEmbedPrompt:
    """Tests for ``CompletionRenderer.render_prompt_and_embeds``.

    Prompt embeddings are supplied as base64-encoded ``torch.save`` payloads,
    built with ``_create_test_embed_bytes``.
    """

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and base64-encode the result."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek/read round-trip is needed.
        return pybase64.b64encode(buffer.getvalue())

    @pytest.mark.asyncio
    async def test_single_prompt_embed(self, renderer):
        test_tensor = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(cache_salt="test_salt"),
        )

        assert len(results) == 1
        assert is_embeds_prompt(results[0])
        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
        # The cache salt from the config is carried onto the prompt.
        assert results[0]["cache_salt"] == "test_salt"

    @pytest.mark.asyncio
    async def test_multiple_prompt_embeds(self, renderer):
        # Two prompts with different sequence lengths and the same hidden size.
        test_tensors = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]
        embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes_list,
            config=RenderConfig(),
        )

        assert len(results) == 2
        for result, expected in zip(results, test_tensors):
            assert is_embeds_prompt(result)
            assert torch.allclose(result["prompt_embeds"], expected)

    @pytest.mark.asyncio
    async def test_prompt_embed_truncation(self, renderer):
        # 20 positions, truncated to 10 below.
        test_tensor = torch.randn(20, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(truncate_prompt_tokens=10),
        )

        assert len(results) == 1
        # Truncation keeps the last 10 positions.
        expected = test_tensor[-10:]
        assert torch.allclose(results[0]["prompt_embeds"], expected)

    @pytest.mark.asyncio
    async def test_prompt_embed_different_dtypes(self, renderer):
        # The decoded embeddings must keep the dtype they were saved with.
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            test_tensor = torch.randn(5, 256, dtype=dtype)
            embed_bytes = self._create_test_embed_bytes(test_tensor)

            results = await renderer.render_prompt_and_embeds(
                prompt_embeds=embed_bytes,
                config=RenderConfig(),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    @pytest.mark.asyncio
    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
        # A leading batch dimension of size 1 should be squeezed away.
        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_embeds=embed_bytes,
            config=RenderConfig(),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    @pytest.mark.asyncio
    async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
        # Text prompts are tokenized through the mocked async tokenizer.
        mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
        renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer

        test_tensor = torch.randn(5, 256, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(test_tensor)

        results = await renderer.render_prompt_and_embeds(
            prompt_or_prompts="Hello world",
            prompt_embeds=embed_bytes,
            config=RenderConfig(),
        )

        assert len(results) == 2
        # Embed prompts come first in the output...
        assert is_embeds_prompt(results[0])
        # ...followed by the tokenized text prompts.
        assert "prompt_token_ids" in results[1]
        assert results[1]["prompt_token_ids"] == [101, 102, 103]