# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5
from collections.abc import Sequence
6
7
8
9
10
11
12
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

13
14
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
15
16
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
17
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.tokenizers.registry import tokenizer_args_from_config

MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Minimal stand-in for a Hugging Face model config.

    Only ``model_type`` is modelled; it defaults to the placeholder ``"any"``.
    """

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight stand-in for the model config handed to ``HfRenderer``.

    Annotated attributes are dataclass fields (constructor arguments, in
    declaration order); unannotated ones are plain class attributes.
    """

    # Plain class attributes: not dataclass fields, so not constructor args.
    runner_type = "generate"
    tokenizer_revision = None
    tokenizer_mode = "auto"
    # NOTE(review): one MockHFConfig instance is shared by every config object.
    hf_config = MockHFConfig()

    # Dataclass fields, in constructor order.
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
41
42


43
44
45
46
47
@dataclass
class MockVllmConfig:
    """Minimal vLLM config wrapper that only carries ``model_config``."""

    model_config: MockModelConfig


48
49
50
51
52
53
54
55
56
57
58
59
60
@dataclass
class DummyTokenizer:
    """Fake tokenizer emitting one token id per input character.

    ``encode(text)`` returns ``list(range(len(text)))``, clipped to
    ``max_length`` when ``truncation`` is requested. The keyword arguments of
    the most recent ``encode`` call are kept in ``_captured_encode_kwargs`` so
    tests can inspect how the renderer invoked the tokenizer.
    """

    # Not used by this class's own methods; presumably read by the renderer
    # under test.
    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Stays empty until ``encode`` is called for the first time.
        self._captured_encode_kwargs: dict = dict()

    def decode(self, tokens: list[int]):
        # e.g. [1, 2] -> "[1, 2]"
        return str(tokens)

    def encode(self, text: str, **kwargs):
        self._captured_encode_kwargs = kwargs

        num_tokens = len(text)
        limit = kwargs.get("max_length")
        # Clip only when truncation is requested *and* a limit was supplied.
        if kwargs.get("truncation") and limit is not None:
            num_tokens = min(num_tokens, limit)

        return list(range(num_tokens))
69
70


71
72
73
74
75
76
77
def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
):
    """Create an ``HfRenderer`` for ``model_config`` backed by a fake tokenizer.

    Unless ``model_config.skip_tokenizer_init`` is set, the renderer's
    ``_tokenizer`` is replaced with a ``DummyTokenizer`` configured with
    ``truncation_side`` and ``max_chars_per_token``.
    """
    _, tokenizer_name, _, tokenizer_kwargs = tokenizer_args_from_config(model_config)
    renderer = HfRenderer(
        MockVllmConfig(model_config),
        tokenizer_kwargs=dict(tokenizer_kwargs, tokenizer_name=tokenizer_name),
    )

    if model_config.skip_tokenizer_init:
        # Keep whatever tokenizer state HfRenderer set up on its own.
        return renderer

    renderer._tokenizer = DummyTokenizer(
        truncation_side=truncation_side,
        max_chars_per_token=max_chars_per_token,
    )
    return renderer

92

93
94
95
96
97
98
99
100
101
102
103
def _preprocess_prompt(
    mdoel_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Split ``prompt_or_prompts`` into a list and parse each non-bytes entry.

    ``bytes`` entries (base64-encoded embedding payloads in these tests) are
    passed through unchanged; every other entry goes through
    ``parse_model_prompt``.
    """
    parsed = []
    for item in prompt_to_seq(prompt_or_prompts):
        if isinstance(item, bytes):
            parsed.append(item)
        else:
            parsed.append(parse_model_prompt(mdoel_config, item))
    return parsed


107
class TestValidatePrompt:
    """Validation errors raised while rendering malformed prompt input."""

    def test_empty_input(self):
        renderer = _build_renderer(MockModelConfig())

        # Preprocessing stays inside ``raises``: the error may surface from
        # either preprocessing or ``render_prompts``.
        with pytest.raises(ValueError, match="at least one prompt"):
            parsed = _preprocess_prompt(renderer.model_config, [])
            renderer.render_prompts(parsed)

    def test_invalid_type(self):
        renderer = _build_renderer(MockModelConfig())
        mixed_prompts = [[1, 2], ["foo", "bar"]]

        with pytest.raises(TypeError, match="should be a list of integers"):
            parsed = _preprocess_prompt(renderer.model_config, mixed_prompts)  # type: ignore[arg-type]
            renderer.render_prompts(parsed)
121
122
123


class TestRenderPrompt:
    """Render and tokenize token-id and text prompts through ``HfRenderer``."""

    @staticmethod
    def _render(renderer, prompt):
        """Preprocess ``prompt`` and pass it through ``renderer.render_prompts``."""
        return renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, prompt)
        )

    @classmethod
    def _render_and_tokenize(cls, renderer, prompt, **params):
        """Render ``prompt``, then tokenize it with ``TokenizeParams(**params)``."""
        rendered = cls._render(renderer, prompt)
        return renderer.tokenize_prompts(rendered, TokenizeParams(**params))

    def test_token_input(self):
        renderer = _build_renderer(MockModelConfig())
        token_ids = [101, 7592, 2088]

        results = self._render_and_tokenize(
            renderer, token_ids, max_total_tokens=100
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == token_ids

    def test_token_list_input(self):
        renderer = _build_renderer(MockModelConfig())
        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        # Snapshot the expected ids before the renderer sees the inputs.
        expected = [list(ids) for ids in token_lists]

        results = self._render_and_tokenize(
            renderer, token_lists, max_total_tokens=100
        )

        assert len(results) == 3
        for want, result in zip(expected, results):
            assert result["prompt_token_ids"] == want

    def test_text_input(self):
        renderer = _build_renderer(MockModelConfig())

        results = self._render_and_tokenize(
            renderer, "x" * 10, max_total_tokens=100
        )

        assert len(results) == 1
        # DummyTokenizer yields one token per character.
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        renderer = _build_renderer(MockModelConfig())
        lengths = [10, 12, 14]
        texts = ["x" * n for n in lengths]

        results = self._render_and_tokenize(
            renderer, texts, max_total_tokens=100
        )

        assert len(results) == 3
        assert [len(r["prompt_token_ids"]) for r in results] == lengths

    def test_zero_truncation(self):
        renderer = _build_renderer(MockModelConfig())

        results = self._render_and_tokenize(
            renderer, "x" * 200, max_total_tokens=100, truncate_prompt_tokens=0
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        renderer = _build_renderer(MockModelConfig())

        results = self._render_and_tokenize(
            renderer, "x" * 200, max_total_tokens=100, truncate_prompt_tokens=50
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        renderer = _build_renderer(MockModelConfig())

        results = self._render_and_tokenize(
            renderer, "x" * 200, max_total_tokens=100, truncate_prompt_tokens=-1
        )

        assert len(results) == 1
        # With -1 the prompt is capped at max_total_tokens.
        assert len(results[0]["prompt_token_ids"]) == 100

    def test_truncation_left(self):
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")
        long_tokens = list(range(100, 110))  # 10 tokens: 100..109

        results = self._render_and_tokenize(
            renderer, long_tokens, max_total_tokens=100, truncate_prompt_tokens=5
        )

        assert len(results) == 1
        # Left-side truncation keeps the last 5 tokens.
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")
        long_tokens = list(range(100, 110))  # 10 tokens: 100..109

        results = self._render_and_tokenize(
            renderer, long_tokens, max_total_tokens=100, truncate_prompt_tokens=5
        )

        assert len(results) == 1
        # Right-side truncation keeps the first 5 tokens.
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)
        # 150 chars exceeds both max_total_tokens (100) and
        # max_total_tokens * max_chars_per_token (100 * 1).
        rendered = self._render(renderer, "x" * 150)

        with pytest.raises(
            ValueError,
            match="input characters and requested .* context length is only",
        ):
            renderer.tokenize_prompts(rendered, TokenizeParams(max_total_tokens=100))

        # The character-count check rejects the prompt before encode() runs.
        assert renderer._tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
        # 150 chars exceeds max_total_tokens (100) but not
        # max_total_tokens * max_chars_per_token (100 * 2).
        rendered = self._render(renderer, "x" * 150)

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(rendered, TokenizeParams(max_total_tokens=100))

        # encode() ran, but was asked for at most max_total_tokens + 1 tokens.
        captured = renderer._tokenizer._captured_encode_kwargs
        assert captured["truncation"] is True
        assert captured["max_length"] == 101

    def test_token_max_length_exceeded(self):
        renderer = _build_renderer(MockModelConfig())
        # 150 token ids exceeds max_total_tokens=100.
        rendered = self._render(renderer, list(range(150)))

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                rendered,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
        rendered = self._render(renderer, "Hello world")

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(rendered, TokenizeParams(max_total_tokens=100))

    def test_token_input_with_needs_detokenization(self):
        renderer = _build_renderer(MockModelConfig())
        token_ids = [1, 2, 3, 4]

        results = self._render_and_tokenize(
            renderer, token_ids, max_total_tokens=100, needs_detokenization=True
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == token_ids
        # DummyTokenizer.decode renders the ids with str().
        assert results[0]["prompt"] == "[1, 2, 3, 4]"
352
353
354
355
356
357
358
359
360
361


class TestRenderEmbedPrompt:
    """Rendering of base64-encoded ``torch.save`` payloads as prompt embeddings."""

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and base64-encode the bytes."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        renderer = _build_renderer(MockModelConfig())

        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, embed_bytes)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        renderer = _build_renderer(MockModelConfig())

        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [self._create_test_embed_bytes(t) for t in tensor_inputs],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # Same lossless save/encode path as test_single_prompt_embed, so demand
        # exact equality there too rather than a tolerance-based comparison.
        for result, expected in zip(results, tensor_inputs):
            assert torch.equal(result["prompt_embeds"], expected)

    def test_prompt_embed_truncation(self):
        renderer = _build_renderer(MockModelConfig())

        # More positions (20) than the truncation limit (10).
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Truncation keeps the last 10 positions.
        assert torch.equal(results[0]["prompt_embeds"], tensor_input[-10:])

    def test_prompt_embed_different_dtypes(self):
        renderer = _build_renderer(MockModelConfig())

        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            tensor_input = torch.randn(5, 256, dtype=dtype)

            prompts = renderer.render_prompts(
                _preprocess_prompt(
                    renderer.model_config, self._create_test_embed_bytes(tensor_input)
                )
            )
            results = renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

            assert len(results) == 1
            # The decoded embedding keeps the dtype it was saved with.
            assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        renderer = _build_renderer(MockModelConfig())

        # A leading batch dimension of size 1 ...
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # ... is squeezed away, leaving a 2-D (positions, hidden) tensor.
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [text_input, self._create_test_embed_bytes(tensor_input)],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # The text prompt becomes token ids (one per character) ...
        assert "prompt_token_ids" in results[0]
        assert len(results[0]["prompt_token_ids"]) == len(text_input)
        # ... while the bytes prompt becomes an embedding prompt.
        assert torch.equal(results[1]["prompt_embeds"], tensor_input)