test_completions.py 15.7 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5
from collections.abc import Sequence
6
7
8
9
10
11
12
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

13
14
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
15
16
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
17
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq

# Model/tokenizer identifier shared by the mock configs in this module.
MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Minimal stand-in for a Hugging Face model config; only carries ``model_type``."""

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight stand-in for ``vllm.config.ModelConfig`` used to build renderers.

    Annotated attributes are dataclass fields and can be overridden through the
    constructor (e.g. ``MockModelConfig(skip_tokenizer_init=True)``). Unannotated
    attributes are plain class attributes shared by every instance and are not
    accepted by ``__init__``.
    """

    runner_type = "generate"  # class attribute (not a dataclass field)
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_revision = None  # class attribute (not a dataclass field)
    tokenizer_mode = "auto"  # class attribute (not a dataclass field)
    # Left unannotated so a single MockHFConfig instance is shared as a class
    # attribute rather than becoming a dataclass field default.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    # When True, _build_renderer attaches no tokenizer to the renderer.
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False


@dataclass
class MockParallelConfig:
    """Stand-in for vLLM's parallel config; exposes only ``_api_process_rank``.

    NOTE(review): presumably this is the only parallel-config attribute the
    renderer reads in these tests — confirm if HfRenderer starts needing more.
    """

    _api_process_rank: int = 0


@dataclass
class MockVllmConfig:
    """Bundles the mock model and parallel configs passed to ``HfRenderer``."""

    model_config: MockModelConfig
    parallel_config: MockParallelConfig


@dataclass
class DummyTokenizer:
    """Deterministic tokenizer stub that emits one token per input character.

    ``encode`` records the keyword arguments of its most recent call in
    ``_captured_encode_kwargs``. Tests use this to check how (or whether) the
    renderer invoked the tokenizer.
    """

    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Stays empty until ``encode`` is called for the first time.
        self._captured_encode_kwargs: dict[str, Any] = {}

    def decode(self, tokens: list[int]) -> str:
        """Return the ``str`` form of ``tokens``, e.g. ``"[1, 2]"``."""
        return str(tokens)

    def encode(self, text: str, **kwargs: Any) -> list[int]:
        """Return ``[0, 1, ..., n - 1]`` where ``n == len(text)``.

        If ``truncation`` is truthy and ``max_length`` is not ``None``, ``n`` is
        capped at ``max_length``. ``truncation_side`` is not consulted, because
        the returned IDs only encode the token count.
        """
        self._captured_encode_kwargs = kwargs

        num_tokens = len(text)
        max_length = kwargs.get("max_length")
        if kwargs.get("truncation") and max_length is not None:
            num_tokens = min(num_tokens, max_length)
        return list(range(num_tokens))


def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
):
    """Build an ``HfRenderer`` over the mock configs and a ``DummyTokenizer``.

    Args:
        model_config: Mock model config. If ``skip_tokenizer_init`` is set, the
            renderer is built with ``tokenizer=None``.
        truncation_side: Forwarded to the ``DummyTokenizer``.
        max_chars_per_token: Forwarded to the ``DummyTokenizer``.

    Returns:
        The constructed ``HfRenderer``.
    """
    tokenizer = (
        None
        if model_config.skip_tokenizer_init
        else DummyTokenizer(
            truncation_side=truncation_side,
            max_chars_per_token=max_chars_per_token,
        )
    )
    return HfRenderer(
        MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
        tokenizer=tokenizer,
    )


def _preprocess_prompt(
    model_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Normalize one or more prompts into the list passed to ``render_prompts``.

    ``prompt_to_seq`` turns the input into a sequence of prompts. ``bytes``
    items are passed through unchanged; in this module those are the
    base64-encoded serialized tensors used as prompt embeds. Every other item
    is parsed with ``parse_model_prompt``.
    """
    return [
        (
            prompt
            if isinstance(prompt, bytes)
            else parse_model_prompt(model_config, prompt)
        )
        for prompt in prompt_to_seq(prompt_or_prompts)
    ]


class TestValidatePrompt:
    """Empty or malformed prompt inputs are rejected during preprocessing/rendering."""

    def test_empty_input(self):
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_prompts(_preprocess_prompt(renderer.model_config, []))

    def test_invalid_type(self):
        renderer = _build_renderer(MockModelConfig())

        # The second prompt is a list of strings rather than token IDs.
        with pytest.raises(TypeError, match="should be a list of integers"):
            renderer.render_prompts(
                _preprocess_prompt(renderer.model_config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
            )


class TestRenderPrompt:
    """Render and tokenize token-ID and text prompts through ``HfRenderer``.

    ``DummyTokenizer`` emits one token per input character, so text lengths
    translate directly into token counts in the assertions below.
    """

    def test_tokens_input(self):
        renderer = _build_renderer(MockModelConfig())

        tokens = [101, 7592, 2088]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    def test_token_list_input(self):
        renderer = _build_renderer(MockModelConfig())

        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, token_lists)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    def test_text_input(self):
        renderer = _build_renderer(MockModelConfig())

        text_input = "x" * 10
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        renderer = _build_renderer(MockModelConfig())

        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_list_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        for text_input, result in zip(text_list_input, results, strict=True):
            assert len(result["prompt_token_ids"]) == len(text_input)

    def test_zero_truncation(self):
        """``truncate_prompt_tokens=0`` leaves no prompt tokens."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        """A positive ``truncate_prompt_tokens`` caps the prompt at that length."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        """``truncate_prompt_tokens=-1`` truncates to ``max_total_tokens``."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens

    def test_truncation_left(self):
        """Left-side truncation keeps the tail of a token prompt."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        """Right-side truncation keeps the head of a token prompt."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        """Text that is clearly too long is rejected without calling the tokenizer."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)

        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_tokens = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(
            ValueError,
            match="maximum context length is",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should not even attempt tokenization
        assert renderer.tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        """Text within the char budget is tokenized (truncated) before rejection."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)

        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_tokens = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(
            ValueError,
            match="maximum context length is",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should only tokenize the first max_total_tokens + 1 tokens
        assert renderer.tokenizer._captured_encode_kwargs["truncation"] is True
        assert renderer.tokenizer._captured_encode_kwargs["max_length"] == 101

    def test_token_max_length_exceeded(self):
        renderer = _build_renderer(MockModelConfig())

        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(
            ValueError,
            match="maximum context length is",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        """Text prompts cannot be tokenized when no tokenizer was initialized."""
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "Hello world")
        )

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

    def test_tokens_input_with_needs_detokenization(self):
        """With ``needs_detokenization`` the prompt text is decoded from the IDs."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [1, 2, 3, 4]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                needs_detokenization=True,
            ),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        # DummyTokenizer.decode returns str(tokens).
        assert results[0]["prompt"] == "[1, 2, 3, 4]"


class TestRenderEmbedPrompt:
    """Render prompts given as base64-encoded serialized tensors (prompt embeds).

    The tensors round-trip losslessly through ``torch.save``/base64, so results
    are compared exactly with ``torch.equal``.
    """

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and base64-encode the bytes."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        renderer = _build_renderer(MockModelConfig())

        # Create a test tensor
        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, embed_bytes)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        renderer = _build_renderer(MockModelConfig())

        # Create multiple test tensors
        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [self._create_test_embed_bytes(t) for t in tensor_inputs],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # Exact comparison, consistent with the other embed tests: the
        # serialization round-trip is lossless.
        for result, tensor_input in zip(results, tensor_inputs, strict=True):
            assert torch.equal(result["prompt_embeds"], tensor_input)

    def test_prompt_embed_truncation(self):
        renderer = _build_renderer(MockModelConfig())

        # Create tensor with more tokens than truncation limit
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = tensor_input[-10:]
        assert torch.equal(results[0]["prompt_embeds"], expected)

    def test_prompt_embed_different_dtypes(self):
        renderer = _build_renderer(MockModelConfig())

        # Test different supported dtypes
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            tensor_input = torch.randn(5, 256, dtype=dtype)

            prompts = renderer.render_prompts(
                _preprocess_prompt(
                    renderer.model_config, self._create_test_embed_bytes(tensor_input)
                )
            )
            results = renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        renderer = _build_renderer(MockModelConfig())

        # Test tensor with batch dimension gets squeezed
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [text_input, self._create_test_embed_bytes(tensor_input)],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # First should be tokens prompt (DummyTokenizer: one token per char)
        assert "prompt_token_ids" in results[0]
        assert len(results[0]["prompt_token_ids"]) == len(text_input)
        # Second should be embed prompt
        assert torch.equal(results[1]["prompt_embeds"], tensor_input)