# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5
from collections.abc import Sequence
6
7
8
9
10
11
12
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

13
14
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
15
16
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
17
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

# Identifier used as the default for both ``model`` and ``tokenizer`` in
# ``MockModelConfig`` below.
MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Bare-bones substitute for a Hugging Face config object.

    It carries a single ``model_type`` field, which defaults to the
    placeholder value ``"any"``.
    """

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight stand-in for the model configuration used by these tests.

    Instances are wrapped in ``MockVllmConfig`` and handed to the renderer
    by ``_build_renderer``.  Tests override individual fields by keyword,
    e.g. ``MockModelConfig(skip_tokenizer_init=True)``.
    """

    # Unannotated names are plain class attributes, not dataclass fields:
    # they are shared by every instance and are not accepted by the
    # generated ``__init__``.
    runner_type = "generate"
    tokenizer_revision = None
    tokenizer_mode = "auto"
    hf_config = MockHFConfig()

    # Dataclass fields (declaration order defines the positional order of
    # the generated ``__init__``).
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    # When True, ``_build_renderer`` passes ``tokenizer=None`` to the renderer.
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False
    renderer_num_workers: int = 1
42
43


44
45
46
47
48
@dataclass
class MockParallelConfig:
    """Placeholder parallel configuration exposing only the API process rank."""

    # Defaults to rank 0.
    _api_process_rank: int = 0


49
50
51
@dataclass
class MockVllmConfig:
    """Top-level mock config bundling the model and parallel mock configs.

    Both fields are required; ``_build_renderer`` constructs this object and
    passes it as the first argument to ``HfRenderer``.
    """

    model_config: MockModelConfig
    parallel_config: MockParallelConfig
53
54


55
56
57
58
59
60
61
62
63
64
65
66
67
@dataclass
class DummyTokenizer:
    """Fake tokenizer that produces exactly one token id per input character.

    ``truncation_side`` and ``max_chars_per_token`` are only stored here; the
    class itself never reads them (presumably the renderer does -- confirm
    against ``HfRenderer``).
    """

    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Keyword arguments of the most recent ``encode`` call.  It stays an
        # empty dict until ``encode`` is invoked, which lets tests assert that
        # tokenization was never attempted.
        self._captured_encode_kwargs: dict[str, Any] = {}

    def decode(self, tokens: list[int]) -> str:
        """Return the ``str()`` rendering of *tokens* (e.g. ``"[1, 2]"``)."""
        return str(tokens)

    def encode(self, text: str, **kwargs: Any) -> list[int]:
        """Return ``[0, 1, ..., k - 1]`` where ``k`` is the token count.

        ``k`` equals ``len(text)``, capped at ``max_length`` when a truthy
        ``truncation`` and a non-``None`` ``max_length`` are both supplied.
        The keyword arguments are recorded in ``_captured_encode_kwargs``.
        """
        self._captured_encode_kwargs = kwargs

        in_length = len(text)
        truncation = kwargs.get("truncation")
        max_length = kwargs.get("max_length")
        if truncation and max_length is not None:
            return list(range(min(in_length, max_length)))

        return list(range(in_length))
76
77


78
79
80
81
82
83
84
def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
) -> HfRenderer:
    """Create an ``HfRenderer`` wired to the mock configs and dummy tokenizer.

    Args:
        model_config: Mock model configuration wrapped in ``MockVllmConfig``.
        truncation_side: Forwarded to ``DummyTokenizer``.
        max_chars_per_token: Forwarded to ``DummyTokenizer``.

    Returns:
        The constructed renderer.  Its tokenizer is ``None`` when
        ``model_config.skip_tokenizer_init`` is true.
    """
    if model_config.skip_tokenizer_init:
        tokenizer = None
    else:
        tokenizer = DummyTokenizer(
            truncation_side=truncation_side,
            max_chars_per_token=max_chars_per_token,
        )

    return HfRenderer(
        MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
        tokenizer=tokenizer,
    )

98

99
def _preprocess_prompt(
    model_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Turn raw test input into the list passed to ``render_prompts``.

    The input is first normalized to a sequence with ``prompt_to_seq``.
    ``bytes`` items (the serialized prompt embeddings used in this module)
    are passed through untouched; every other item is run through
    ``parse_model_prompt``.

    Args:
        model_config: Config forwarded to ``parse_model_prompt``.
        prompt_or_prompts: A single prompt or a sequence of prompts.

    Returns:
        A list with one entry per prompt, in input order.
    """
    return [
        prompt if isinstance(prompt, bytes) else parse_model_prompt(model_config, prompt)
        for prompt in prompt_to_seq(prompt_or_prompts)
    ]


113
class TestValidatePrompt:
    """Checks that malformed prompt input is rejected."""

    def test_empty_input(self):
        """An empty prompt list raises ``ValueError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_prompts(_preprocess_prompt(renderer.model_config, []))

    def test_invalid_type(self):
        """A batch mixing a token list with a list of strings raises ``TypeError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(TypeError, match="should be a list of integers"):
            renderer.render_prompts(
                _preprocess_prompt(renderer.model_config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
            )
127
128
129


class TestRenderPrompt:
    """Rendering and tokenization of text and token-id prompts.

    Every test builds a renderer backed by ``DummyTokenizer``, which yields
    one token per character, so expected token counts equal text lengths.
    """

    def test_tokens_input(self):
        """A single list of token ids is returned unchanged."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [101, 7592, 2088]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    def test_token_list_input(self):
        """A batch of token-id lists yields one result per list, in order."""
        renderer = _build_renderer(MockModelConfig())

        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, token_lists)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    def test_text_input(self):
        """A single text prompt is tokenized to one token per character."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "x" * 10
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        """A batch of text prompts yields one tokenized result per prompt."""
        renderer = _build_renderer(MockModelConfig())

        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_list_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        # strict=True is safe: the length equality was asserted just above.
        for text_input, result in zip(text_list_input, results, strict=True):
            assert len(result["prompt_token_ids"]) == len(text_input)

    def test_zero_truncation(self):
        """``truncate_prompt_tokens=0`` yields an empty token list."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        """A positive ``truncate_prompt_tokens`` caps the prompt at that length."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        """``truncate_prompt_tokens=-1`` truncates to ``max_total_tokens``."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens

    def test_truncation_left(self):
        """With left truncation the trailing tokens of a token prompt are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        """With right truncation the leading tokens of a token prompt are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        """Over-long text is rejected without calling the tokenizer."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)

        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_tokens = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(ValueError, match="maximum context length is"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should not even attempt tokenization
        assert renderer.tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        """Text within the character budget is tokenized with a capped length."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)

        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_tokens = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(ValueError, match="maximum context length is"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should only tokenize the first max_total_tokens + 1 tokens
        encode_kwargs = renderer.tokenizer._captured_encode_kwargs
        assert encode_kwargs["truncation"] is True
        assert encode_kwargs["max_length"] == 101

    def test_token_max_length_exceeded(self):
        """Token prompts longer than ``max_total_tokens`` are rejected."""
        renderer = _build_renderer(MockModelConfig())

        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(ValueError, match="maximum context length is"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        """Tokenizing text without a tokenizer raises ``ValueError``."""
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "Hello world")
        )

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

    def test_tokens_input_with_needs_detokenization(self):
        """``needs_detokenization=True`` adds the decoded text under ``"prompt"``."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [1, 2, 3, 4]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                needs_detokenization=True,
            ),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        # DummyTokenizer.decode returns str(tokens).
        assert results[0]["prompt"] == "[1, 2, 3, 4]"
358
359
360
361
362
363
364
365
366
367


class TestRenderEmbedPrompt:
    """Rendering of prompts supplied as serialized embedding tensors."""

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize *tensor* with ``torch.save`` and return it base64-encoded."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # getvalue() returns the whole buffer regardless of stream position.
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        """A single encoded tensor round-trips to an identical tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Create a test tensor
        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, embed_bytes)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        """A batch of encoded tensors yields one embed prompt per tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Create multiple test tensors
        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [self._create_test_embed_bytes(t) for t in tensor_inputs],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # strict=True is safe: the length equality was asserted just above.
        for expected, result in zip(tensor_inputs, results, strict=True):
            assert torch.allclose(result["prompt_embeds"], expected)

    def test_prompt_embed_truncation(self):
        """``truncate_prompt_tokens`` keeps the trailing rows of the embedding."""
        renderer = _build_renderer(MockModelConfig())

        # Create tensor with more tokens than truncation limit
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = tensor_input[-10:]
        assert torch.equal(results[0]["prompt_embeds"], expected)

    # Parametrized so that each dtype is reported as its own test case rather
    # than the first failing dtype hiding the remaining ones.
    @pytest.mark.parametrize(
        "dtype", [torch.float32, torch.float16, torch.bfloat16]
    )
    def test_prompt_embed_different_dtypes(self, dtype):
        """The dtype of the input tensor is preserved in the rendered prompt."""
        renderer = _build_renderer(MockModelConfig())

        tensor_input = torch.randn(5, 256, dtype=dtype)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        """A leading batch dimension of size one is removed."""
        renderer = _build_renderer(MockModelConfig())

        # Test tensor with batch dimension gets squeezed
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        """Text and embedding prompts can be mixed within one batch."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [text_input, self._create_test_embed_bytes(tensor_input)],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # First should be tokens prompt
        assert "prompt_token_ids" in results[0]
        assert len(results[0]["prompt_token_ids"]) == len(text_input)
        # Second should be embed prompt
        assert torch.equal(results[1]["prompt_embeds"], tensor_input)