test_completions.py 15.5 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
5
from collections.abc import Sequence
6
7
8
9
10
11
12
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

13
14
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
15
16
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
17
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from vllm.tokenizers.registry import tokenizer_args_from_config

MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Stand-in object used as the ``hf_config`` attribute of the mock model config."""

    # Placeholder value; nothing in this file sets a different model type.
    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Lightweight model-config double handed to the renderer under test.

    Annotated attributes are dataclass fields and can be overridden per test
    through the constructor (for example ``skip_tokenizer_init=True``).
    Un-annotated attributes are ordinary class attributes shared by every
    instance and are not constructor parameters.
    """

    runner_type = "generate"
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_revision = None
    tokenizer_mode = "auto"
    # Un-annotated, so this one MockHFConfig instance is a class attribute
    # shared by all MockModelConfig instances rather than a per-instance field.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    skip_tokenizer_init: bool = False
    is_encoder_decoder: bool = False
41
42


43
44
45
46
47
48
49
50
51
52
53
54
55
@dataclass
class DummyTokenizer:
    """Fake tokenizer that produces one token id per input character.

    ``encode`` returns ``[0, 1, ..., n - 1]`` for a text of ``n`` characters
    and records the keyword arguments it was called with, so tests can check
    how (and whether) tokenization was requested.

    ``truncation_side`` and ``max_chars_per_token`` are only stored here; this
    class never reads them itself (presumably the renderer does -- verify
    against the renderer implementation).
    """

    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Keyword arguments of the most recent ``encode`` call. Stays an empty
        # dict until ``encode`` is invoked for the first time.
        self._captured_encode_kwargs: dict[str, Any] = {}

    def decode(self, tokens: list[int]) -> str:
        """Return the ``str()`` form of the token list, e.g. ``"[1, 2]"``."""
        return str(tokens)

    def encode(self, text: str, **kwargs: Any) -> list[int]:
        """Return one sequential id per character of ``text``.

        When ``truncation`` is truthy and ``max_length`` is not ``None``, at
        most ``max_length`` ids are returned. All keyword arguments are saved
        in ``_captured_encode_kwargs`` before anything else happens.
        """
        self._captured_encode_kwargs = kwargs

        in_length = len(text)
        truncation = kwargs.get("truncation")
        max_length = kwargs.get("max_length")
        if truncation and max_length is not None:
            return list(range(min(in_length, max_length)))

        return list(range(in_length))
64
65


66
67
68
69
70
71
72
def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
) -> HfRenderer:
    """Create an ``HfRenderer`` for ``model_config`` backed by a fake tokenizer.

    Args:
        model_config: Mock configuration passed to
            ``tokenizer_args_from_config`` and to the renderer constructor.
        truncation_side: Forwarded to the ``DummyTokenizer`` that is installed.
        max_chars_per_token: Forwarded to the ``DummyTokenizer`` that is
            installed.

    Returns:
        The renderer. Unless ``model_config.skip_tokenizer_init`` is set, its
        ``_tokenizer`` attribute has been replaced by a ``DummyTokenizer``.
    """
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    renderer = HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )

    # When tokenizer initialisation is skipped, the renderer is left exactly
    # as constructed so tests can exercise the "no tokenizer" path.
    if not model_config.skip_tokenizer_init:
        renderer._tokenizer = DummyTokenizer(
            truncation_side=truncation_side,
            max_chars_per_token=max_chars_per_token,
        )

    return renderer

87

88
89
90
91
92
93
94
95
96
97
98
def _preprocess_prompt(
    mdoel_config: ModelConfig,
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
    """Turn raw test input into the list handed to ``renderer.render_prompts``.

    The input is first normalised with ``prompt_to_seq``. Each ``bytes`` entry
    is passed through unchanged; every other entry is converted with
    ``parse_model_prompt``.

    Args:
        mdoel_config: Model configuration forwarded to ``parse_model_prompt``.
            NOTE(review): the name is a misspelling of ``model_config``; it is
            kept as-is so the signature stays unchanged for existing callers.
        prompt_or_prompts: A single prompt, a single ``bytes`` payload, or a
            sequence of either.

    Returns:
        A list with one entry per prompt yielded by ``prompt_to_seq``.
    """
    return [
        (
            prompt
            if isinstance(prompt, bytes)
            else parse_model_prompt(mdoel_config, prompt)
        )
        for prompt in prompt_to_seq(prompt_or_prompts)
    ]


102
class TestValidatePrompt:
    """Input-validation errors raised while rendering prompts."""

    def test_empty_input(self):
        """An empty prompt list is rejected with a ``ValueError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_prompts(_preprocess_prompt(renderer.config, []))

    def test_invalid_type(self):
        """A batch mixing an int list with a str list raises ``TypeError``."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(TypeError, match="should be a list of integers"):
            renderer.render_prompts(
                _preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]])  # type: ignore[arg-type]
            )
116
117
118


class TestRenderPrompt:
    """Rendering and tokenizing of text and token-id prompts."""

    def test_token_input(self):
        """A single list of token ids is returned unchanged."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [101, 7592, 2088]
        prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    def test_token_list_input(self):
        """A batch of token-id lists yields one result per list, in order."""
        renderer = _build_renderer(MockModelConfig())

        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, token_lists)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    def test_text_input(self):
        """A single text prompt is tokenized (one id per character here)."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "x" * 10
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, text_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        """A batch of text prompts yields one tokenized result per text."""
        renderer = _build_renderer(MockModelConfig())

        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, text_list_input)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        # The length was asserted above, so strict pairing cannot raise here;
        # it only guards against silent truncation if that assert is changed.
        for text_input, result in zip(text_list_input, results, strict=True):
            assert len(result["prompt_token_ids"]) == len(text_input)

    def test_zero_truncation(self):
        """``truncate_prompt_tokens=0`` leaves no prompt tokens."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        """A positive ``truncate_prompt_tokens`` caps the prompt at that size."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        """``truncate_prompt_tokens=-1`` truncates to ``max_total_tokens``."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, "x" * 200)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens

    def test_truncation_left(self):
        """With left truncation the trailing tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        """With right truncation the leading tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, long_tokens)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        """Over-long text is rejected on character count, without encoding."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)

        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, long_text)
        )

        with pytest.raises(
            ValueError,
            match="input characters and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should not even attempt tokenization
        assert renderer._tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        """Text within the character bound is encoded, then rejected on tokens."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)

        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, long_text)
        )

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should only tokenize the first max_total_tokens + 1 tokens
        assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
        assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101

    def test_token_max_length_exceeded(self):
        """Too many token ids with truncation disabled raises ``ValueError``."""
        renderer = _build_renderer(MockModelConfig())

        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, long_tokens)
        )

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        """Tokenizing text with ``skip_tokenizer_init=True`` raises ``ValueError``."""
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, "Hello world")
        )

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

    def test_token_input_with_needs_detokenization(self):
        """``needs_detokenization=True`` adds the decoded text as ``prompt``."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [1, 2, 3, 4]
        prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                needs_detokenization=True,
            ),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        # DummyTokenizer.decode returns str(tokens).
        assert results[0]["prompt"] == "[1, 2, 3, 4]"
343
344
345
346
347
348
349
350
351
352


class TestRenderEmbedPrompt:
    """Rendering of prompts supplied as serialized embedding tensors."""

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize ``tensor`` with ``torch.save`` and base64-encode the bytes."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # getvalue() returns the whole buffer regardless of stream position.
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        """One encoded tensor comes back as an equal ``prompt_embeds`` tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Create a test tensor
        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_prompts(
            _preprocess_prompt(renderer.config, embed_bytes)
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        """A batch of encoded tensors yields one embed result per tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Create multiple test tensors
        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.config,
                [self._create_test_embed_bytes(t) for t in tensor_inputs],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        for result, tensor_input in zip(results, tensor_inputs, strict=True):
            assert torch.allclose(result["prompt_embeds"], tensor_input)

    def test_prompt_embed_truncation(self):
        """``truncate_prompt_tokens`` keeps only the trailing embedding rows."""
        renderer = _build_renderer(MockModelConfig())

        # Create tensor with more tokens than truncation limit
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = tensor_input[-10:]
        assert torch.equal(results[0]["prompt_embeds"], expected)

    def test_prompt_embed_different_dtypes(self):
        """The dtype of the embedding tensor survives the round trip."""
        renderer = _build_renderer(MockModelConfig())

        # Test different supported dtypes
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            tensor_input = torch.randn(5, 256, dtype=dtype)

            prompts = renderer.render_prompts(
                _preprocess_prompt(
                    renderer.config, self._create_test_embed_bytes(tensor_input)
                )
            )
            results = renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        """A leading batch dimension of size one is removed from the result."""
        renderer = _build_renderer(MockModelConfig())

        # Test tensor with batch dimension gets squeezed
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.config, self._create_test_embed_bytes(tensor_input)
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        """Text and embedding prompts can be mixed within a single batch."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_prompts(
            _preprocess_prompt(
                renderer.config,
                [text_input, self._create_test_embed_bytes(tensor_input)],
            )
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # First should be tokens prompt
        assert "prompt_token_ids" in results[0]
        assert len(results[0]["prompt_token_ids"]) == len(text_input)
        # Second should be embed prompt
        assert torch.equal(results[1]["prompt_embeds"], tensor_input)