# test_completions.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5
from collections.abc import Sequence
6
7
8
9
10
11
12
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

13
14
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
15
16
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
17
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.tokenizers.registry import tokenizer_args_from_config

MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Stand-in for a Hugging Face config object exposing only ``model_type``."""

    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight stand-in for the model config consumed by these tests.

    Only the annotated attributes are dataclass fields and therefore accepted
    by the generated ``__init__``. The unannotated ones (``runner_type``,
    ``tokenizer_revision``, ``tokenizer_mode``, ``hf_config``) are plain class
    attributes shared by every instance.
    """

    # Class attribute (no annotation), not a constructor argument.
    runner_type = "generate"
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    # Class attributes (no annotation), not constructor arguments.
    tokenizer_revision = None
    tokenizer_mode = "auto"
    # A single MockHFConfig instance, created once at class-definition time.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    # When True, ``_build_renderer`` passes ``tokenizer=None`` to the renderer.
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False
42
43


44
45
46
47
48
@dataclass
class MockVllmConfig:
    """Minimal top-level config wrapper that carries only ``model_config``."""

    model_config: MockModelConfig


49
50
51
52
53
54
55
56
57
58
59
60
61
@dataclass
class DummyTokenizer:
    """Minimal tokenizer stand-in that produces one token per input character.

    ``truncation_side`` and ``max_chars_per_token`` are only stored here;
    neither ``encode`` nor ``decode`` reads them (presumably the renderer
    does -- confirm against its implementation).
    """

    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Keyword arguments of the most recent ``encode`` call. It stays an
        # empty dict until ``encode`` runs, which lets tests assert that no
        # tokenization was attempted.
        self._captured_encode_kwargs: dict = {}

    def decode(self, tokens: list[int]) -> str:
        """Return the ``str()`` representation of ``tokens``."""
        return str(tokens)

    def encode(self, text: str, **kwargs) -> list[int]:
        """Return ``[0, 1, ..., n - 1]`` where ``n`` is the token count.

        ``n`` is ``len(text)``, capped at ``max_length`` only when a truthy
        ``truncation`` and a non-``None`` ``max_length`` are both passed.
        The keyword arguments are recorded in ``_captured_encode_kwargs``.
        """
        self._captured_encode_kwargs = kwargs

        num_tokens = len(text)
        max_length = kwargs.get("max_length")
        if kwargs.get("truncation") and max_length is not None:
            num_tokens = min(num_tokens, max_length)

        return list(range(num_tokens))
70
71


72
73
74
75
76
77
78
def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
):
    """Build an ``HfRenderer`` around ``model_config`` and a dummy tokenizer.

    Args:
        model_config: Mock model config, wrapped in a ``MockVllmConfig``.
        truncation_side: Forwarded to ``DummyTokenizer``.
        max_chars_per_token: Forwarded to ``DummyTokenizer``.

    Returns:
        The constructed ``HfRenderer``. Its tokenizer is ``None`` when
        ``model_config.skip_tokenizer_init`` is true, otherwise a
        ``DummyTokenizer``.
    """
    # NOTE(review): the returned tokenizer arguments are not used here. The
    # call is kept in case it validates the config -- confirm whether it can
    # be dropped entirely.
    tokenizer_args_from_config(model_config)

    if model_config.skip_tokenizer_init:
        tokenizer = None
    else:
        tokenizer = DummyTokenizer(
            truncation_side=truncation_side,
            max_chars_per_token=max_chars_per_token,
        )

    return HfRenderer(MockVllmConfig(model_config), tokenizer=tokenizer)

94

95
96
97
98
99
100
101
102
103
104
105
def _preprocess_prompt(
    # NOTE(review): "mdoel_config" is a typo for "model_config". The name is
    # kept so the signature stays unchanged; all calls in this file pass it
    # positionally.
    mdoel_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Turn one prompt or a sequence of prompts into a list of parsed prompts.

    The input is first expanded with ``prompt_to_seq``. Each ``bytes`` item
    is passed through unchanged; every other item is parsed with
    ``parse_model_prompt``.

    Args:
        mdoel_config: Model config forwarded to ``parse_model_prompt``.
        prompt_or_prompts: A single prompt, or a sequence of prompts.

    Returns:
        A list with one entry per item yielded by ``prompt_to_seq``.
    """
    return [
        prompt if isinstance(prompt, bytes) else parse_model_prompt(mdoel_config, prompt)
        for prompt in prompt_to_seq(prompt_or_prompts)
    ]


109
class TestValidatePrompt:
    """Invalid prompt inputs are rejected while preprocessing/rendering."""

    def test_empty_input(self):
        """An empty prompt list raises ``ValueError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_prompts(_preprocess_prompt(renderer.model_config, []))

    def test_invalid_type(self):
        """A batch mixing integer lists and string lists raises ``TypeError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(TypeError, match="should be a list of integers"):
            renderer.render_prompts(
                _preprocess_prompt(renderer.model_config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
            )
123
124
125


class TestRenderPrompt:
    """Token and text prompts through ``render_prompts`` + ``tokenize_prompts``."""

    def test_token_input(self):
        """A single list of token IDs comes back unchanged."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [101, 7592, 2088]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    def test_token_list_input(self):
        """A batch of token-ID lists yields one result per list, in order."""
        renderer = _build_renderer(MockModelConfig())

        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, token_lists)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    def test_text_input(self):
        """A text prompt is tokenized (one dummy token per character)."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "x" * 10
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        """Each text in a batch is tokenized independently."""
        renderer = _build_renderer(MockModelConfig())

        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, text_list_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        for text_input, result in zip(text_list_input, results):
            assert len(result["prompt_token_ids"]) == len(text_input)

    def test_zero_truncation(self):
        """``truncate_prompt_tokens=0`` produces an empty token list."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        """A positive ``truncate_prompt_tokens`` caps the token count."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        """``truncate_prompt_tokens=-1`` truncates to ``max_total_tokens``."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens

    def test_truncation_left(self):
        """With ``truncation_side="left"`` the trailing tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        """With ``truncation_side="right"`` the leading tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        """Text that is obviously too long is rejected without tokenizing."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)

        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_text)
        )

        with pytest.raises(
            ValueError,
            match="input characters and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should not even attempt tokenization
        assert renderer.tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        """Text within the character bound is tokenized, then rejected."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)

        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_text)
        )

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should only tokenize the first max_total_tokens + 1 tokens
        assert renderer.tokenizer._captured_encode_kwargs["truncation"] is True
        assert renderer.tokenizer._captured_encode_kwargs["max_length"] == 101

    def test_token_max_length_exceeded(self):
        """Too many token IDs without truncation raise ``ValueError``."""
        renderer = _build_renderer(MockModelConfig())

        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, long_tokens)
        )

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        """Tokenizing text fails when the renderer was built without a tokenizer."""
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, "Hello world")
        )

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

    def test_token_input_with_needs_detokenization(self):
        """``needs_detokenization=True`` adds the decoded text as ``prompt``."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [1, 2, 3, 4]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                needs_detokenization=True,
            ),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        # DummyTokenizer.decode returns str(tokens).
        assert results[0]["prompt"] == "[1, 2, 3, 4]"
354
355
356
357
358
359
360
361
362
363


class TestRenderEmbedPrompt:
    """Prompt-embedding inputs supplied as base64-encoded ``torch.save`` bytes."""

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and return it base64-encoded."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # getvalue() returns the whole buffer regardless of stream position.
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        """A single embedding payload round-trips to an identical tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Create a test tensor
        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.model_config, embed_bytes)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        """A batch of embedding payloads yields one result per payload, in order."""
        renderer = _build_renderer(MockModelConfig())

        # Create multiple test tensors
        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [self._create_test_embed_bytes(t) for t in tensor_inputs],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        for tensor_input, result in zip(tensor_inputs, results):
            assert torch.allclose(result["prompt_embeds"], tensor_input)

    def test_prompt_embed_truncation(self):
        """Embeddings longer than ``truncate_prompt_tokens`` keep the last rows."""
        renderer = _build_renderer(MockModelConfig())

        # Create tensor with more tokens than truncation limit
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = tensor_input[-10:]
        assert torch.equal(results[0]["prompt_embeds"], expected)

    def test_prompt_embed_different_dtypes(self):
        """The dtype of the embedding tensor is preserved."""
        renderer = _build_renderer(MockModelConfig())

        # Test different supported dtypes
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            tensor_input = torch.randn(5, 256, dtype=dtype)

            prompts = renderer.render_prompts(
                _preprocess_prompt(
                    renderer.model_config, self._create_test_embed_bytes(tensor_input)
                )
            )
            results = renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        """A leading batch dimension of size 1 is removed from the result."""
        renderer = _build_renderer(MockModelConfig())

        # Test tensor with batch dimension gets squeezed
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        """Text and embedding prompts can be mixed in one batch."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.model_config,
                [text_input, self._create_test_embed_bytes(tensor_input)],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # First should be tokens prompt
        assert "prompt_token_ids" in results[0]
        assert len(results[0]["prompt_token_ids"]) == len(text_input)
        # Second should be embed prompt
        assert torch.equal(results[1]["prompt_embeds"], tensor_input)