# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
from dataclasses import dataclass
from typing import Any

import pybase64
import pytest
import torch

from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config

# Default value for the `model` and `tokenizer` fields of `MockModelConfig`.
MODEL_NAME = "openai-community/gpt2"


@dataclass
class MockHFConfig:
    """Mock object used as `MockModelConfig.hf_config`; exposes only `model_type`."""

    model_type: str = "any"

@dataclass
class MockModelConfig:
    """Mock model configuration handed to `_build_renderer` by the tests.

    Only the annotated attributes are dataclass fields (and therefore accepted
    as constructor keyword arguments). The unannotated ones are plain class
    attributes shared by every instance.
    """

    # Class attribute (unannotated, so not a dataclass field).
    runner_type = "generate"
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    # Class attributes (unannotated, so not dataclass fields).
    tokenizer_revision = None
    tokenizer_mode = "auto"
    # Class attribute: a single `MockHFConfig` instance shared by all configs.
    hf_config = MockHFConfig()
    encoder_config: dict[str, Any] | None = None
    enable_prompt_embeds: bool = True
    # When True, `_build_renderer` does not install a `DummyTokenizer`.
    skip_tokenizer_init: bool = False


@dataclass
class DummyTokenizer:
    """Deterministic fake tokenizer that emits one token id per input character.

    Token ids are simply the character positions (``0, 1, 2, ...``), so the
    length of the encoded output is fully predictable from the input text.

    Attributes:
        truncation_side: Stored as given; not read by this class itself.
        max_chars_per_token: Stored as given; not read by this class itself
            (presumably consumed by the renderer -- TODO confirm).
    """

    truncation_side: str = "left"
    max_chars_per_token: int = 1

    def __post_init__(self) -> None:
        # Keyword arguments of the most recent `encode` call. Stays empty until
        # `encode` is invoked, which lets tests detect skipped tokenization.
        self._captured_encode_kwargs: dict[str, Any] = {}

    def decode(self, tokens: list[int]) -> str:
        """Return the `str()` representation of the token list."""
        return str(tokens)

    def encode(self, text: str, **kwargs: Any) -> list[int]:
        """Encode `text` as ``[0, 1, ..., len(text) - 1]``.

        The keyword arguments are recorded in `_captured_encode_kwargs`. When
        `truncation` is truthy and `max_length` is not None, at most
        `max_length` ids are returned.
        """
        self._captured_encode_kwargs = kwargs

        in_length = len(text)
        truncation = kwargs.get("truncation")
        max_length = kwargs.get("max_length")
        if truncation and max_length is not None:
            return list(range(min(in_length, max_length)))

        return list(range(in_length))


def _build_renderer(
    model_config: MockModelConfig,
    *,
    truncation_side: str = "left",
    max_chars_per_token: int = 1,
) -> HfRenderer:
    """Create an `HfRenderer` for `model_config` with a `DummyTokenizer` installed.

    Args:
        model_config: Mock configuration passed to both
            `tokenizer_args_from_config` and `HfRenderer`.
        truncation_side: Forwarded to the `DummyTokenizer`.
        max_chars_per_token: Forwarded to the `DummyTokenizer`.

    Returns:
        The constructed renderer. Its `_tokenizer` attribute is replaced by a
        `DummyTokenizer` unless `model_config.skip_tokenizer_init` is set.
    """
    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)

    renderer = HfRenderer(
        model_config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )

    # Override the renderer's tokenizer with the deterministic fake so that
    # token counts in the tests depend only on the input length.
    if not model_config.skip_tokenizer_init:
        renderer._tokenizer = DummyTokenizer(
            truncation_side=truncation_side,
            max_chars_per_token=max_chars_per_token,
        )

    return renderer

class TestValidatePrompt:
    """Validation and consistency checks for `render_completions` inputs."""

    STRING_INPUTS = [
        "",
        "foo",
        "foo bar",
        "foo baz bar",
        "foo bar qux baz",
    ]

    TOKEN_INPUTS = [
        [-1],
        [1],
        [1, 2],
        [1, 3, 4],
        [1, 2, 4, 3],
    ]

    INPUTS_SLICES = [
        slice(None, None, -1),
        slice(None, None, 2),
        slice(None, None, -2),
    ]

    def test_empty_input(self):
        """An empty prompt list is rejected with a `ValueError`."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(ValueError, match="at least one prompt"):
            renderer.render_completions([])

    def test_invalid_type(self):
        """A nested mixed-type list of lists raises a `TypeError`."""
        renderer = _build_renderer(MockModelConfig())

        with pytest.raises(TypeError, match="string or an array of tokens"):
            renderer.render_completions([[1, 2], ["foo", "bar"]])

    @pytest.mark.parametrize("string_input", STRING_INPUTS)
    def test_string_consistent(self, string_input: str):
        """A bare string renders the same as a one-element list holding it."""
        renderer = _build_renderer(MockModelConfig())

        rendered_bare = renderer.render_completions(string_input)
        rendered_wrapped = renderer.render_completions([string_input])
        assert rendered_bare == rendered_wrapped

    @pytest.mark.parametrize("token_input", TOKEN_INPUTS)
    def test_token_consistent(self, token_input: list[int]):
        """A bare token list renders the same as a one-element list holding it."""
        renderer = _build_renderer(MockModelConfig())

        rendered_bare = renderer.render_completions(token_input)
        rendered_wrapped = renderer.render_completions([token_input])
        assert rendered_bare == rendered_wrapped

    @pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
    def test_string_slice(self, inputs_slice: slice):
        """Slicing the rendered output equals rendering the sliced input."""
        renderer = _build_renderer(MockModelConfig())

        rendered_all = renderer.render_completions(self.STRING_INPUTS)
        sliced_inputs = self.STRING_INPUTS[inputs_slice]
        assert rendered_all[inputs_slice] == renderer.render_completions(sliced_inputs)


class TestRenderPrompt:
    """Tests for `render_completions` + `tokenize_prompts` on text/token input."""

    def test_token_input(self):
        """A single token list is passed through unchanged."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [101, 7592, 2088]
        prompts = renderer.render_completions(tokens)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens

    def test_token_list_input(self):
        """A list of token lists yields one result per list, in order."""
        renderer = _build_renderer(MockModelConfig())

        token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
        prompts = renderer.render_completions(token_lists)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
        assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
        assert results[2]["prompt_token_ids"] == [103, 4567]

    def test_text_input(self):
        """A single string is tokenized into one token per character."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "x" * 10
        prompts = renderer.render_completions(text_input)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 10

    def test_text_list_input(self):
        """A list of strings yields one tokenized result per string."""
        renderer = _build_renderer(MockModelConfig())

        text_list_input = ["x" * 10, "x" * 12, "x" * 14]
        prompts = renderer.render_completions(text_list_input)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 3
        # strict=True: a length mismatch must fail loudly, not be truncated.
        for text_input, result in zip(text_list_input, results, strict=True):
            assert len(result["prompt_token_ids"]) == len(text_input)

    def test_zero_truncation(self):
        """`truncate_prompt_tokens=0` produces an empty token list."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_completions("x" * 200)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 0

    def test_pos_truncation(self):
        """A positive `truncate_prompt_tokens` caps the prompt at that length."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_completions("x" * 200)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 50

    def test_neg_truncation(self):
        """`truncate_prompt_tokens=-1` truncates to `max_total_tokens`."""
        renderer = _build_renderer(MockModelConfig())

        prompts = renderer.render_completions("x" * 200)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
        )

        assert len(results) == 1
        assert len(results[0]["prompt_token_ids"]) == 100  # max_total_tokens

    def test_truncation_left(self):
        """With left truncation the trailing tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="left")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_completions(long_tokens)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
        assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    def test_truncation_right(self):
        """With right truncation the leading tokens are kept."""
        renderer = _build_renderer(MockModelConfig(), truncation_side="right")

        long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]  # 10 tokens
        prompts = renderer.render_completions(long_tokens)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
        )

        assert len(results) == 1
        # Should keep the first 5 tokens: [100, 101, 102, 103, 104]
        assert results[0]["prompt_token_ids"] == [100, 101, 102, 103, 104]

    def test_text_max_length_exceeded_obvious(self):
        """Over-long text is rejected on character count, without tokenizing."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=1)

        # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_completions(long_text)

        with pytest.raises(
            ValueError,
            match="input characters and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should not even attempt tokenization
        assert renderer._tokenizer._captured_encode_kwargs == {}

    def test_text_max_length_exceeded_nonobvious(self):
        """Text within the character budget is tokenized, then rejected."""
        renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)

        # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
        long_text = "x" * 150
        prompts = renderer.render_completions(long_text)

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

        # Should only tokenize the first max_total_tokens + 1 tokens
        assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True
        assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101

    def test_token_max_length_exceeded(self):
        """Token input longer than `max_total_tokens` raises without truncation."""
        renderer = _build_renderer(MockModelConfig())

        long_tokens = list(range(150))  # Exceeds max_total_tokens=100
        prompts = renderer.render_completions(long_tokens)

        with pytest.raises(
            ValueError,
            match="input tokens and requested .* context length is only",
        ):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=None),
            )

    def test_no_tokenizer_for_text(self):
        """Text input cannot be tokenized when `skip_tokenizer_init=True`."""
        renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))

        prompts = renderer.render_completions("Hello world")

        with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
            renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

    def test_token_input_with_needs_detokenization(self):
        """`needs_detokenization=True` adds the decoded text under `prompt`."""
        renderer = _build_renderer(MockModelConfig())

        tokens = [1, 2, 3, 4]
        prompts = renderer.render_completions(tokens)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                needs_detokenization=True,
            ),
        )

        assert len(results) == 1
        assert results[0]["prompt_token_ids"] == tokens
        # `DummyTokenizer.decode` returns `str(tokens)`.
        assert results[0]["prompt"] == "[1, 2, 3, 4]"


class TestRenderEmbedPrompt:
    """Tests for `render_completions` + `tokenize_prompts` on `prompt_embeds`."""

    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
        """Serialize `tensor` with `torch.save` and return it base64-encoded."""
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        # `getvalue()` returns the whole buffer regardless of stream position.
        return pybase64.b64encode(buffer.getvalue())

    def test_single_prompt_embed(self):
        """A single encoded tensor round-trips unchanged."""
        renderer = _build_renderer(MockModelConfig())

        # Create a test tensor
        tensor_input = torch.randn(10, 768, dtype=torch.float32)
        embed_bytes = self._create_test_embed_bytes(tensor_input)

        prompts = renderer.render_completions(prompt_embeds=embed_bytes)
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)

    def test_multiple_prompt_embeds(self):
        """A list of encoded tensors yields one result per tensor, in order."""
        renderer = _build_renderer(MockModelConfig())

        # Create multiple test tensors
        tensor_inputs = [
            torch.randn(8, 512, dtype=torch.float32),
            torch.randn(12, 512, dtype=torch.float32),
        ]

        prompts = renderer.render_completions(
            prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # strict=True: a length mismatch must fail loudly, not be truncated.
        for result, expected in zip(results, tensor_inputs, strict=True):
            assert torch.allclose(result["prompt_embeds"], expected)

    def test_prompt_embed_truncation(self):
        """`truncate_prompt_tokens` keeps the trailing rows of the embedding."""
        renderer = _build_renderer(MockModelConfig())

        # Create tensor with more tokens than truncation limit
        tensor_input = torch.randn(20, 768, dtype=torch.float32)

        prompts = renderer.render_completions(
            prompt_embeds=self._create_test_embed_bytes(tensor_input),
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(
                max_total_tokens=100,
                truncate_prompt_tokens=10,
            ),
        )

        assert len(results) == 1
        # Should keep last 10 tokens
        expected = tensor_input[-10:]
        assert torch.equal(results[0]["prompt_embeds"], expected)

    def test_prompt_embed_different_dtypes(self):
        """The dtype of the decoded embedding matches the encoded tensor."""
        renderer = _build_renderer(MockModelConfig())

        # Test different supported dtypes
        dtypes = [torch.float32, torch.float16, torch.bfloat16]

        for dtype in dtypes:
            tensor_input = torch.randn(5, 256, dtype=dtype)

            prompts = renderer.render_completions(
                prompt_embeds=self._create_test_embed_bytes(tensor_input),
            )
            results = renderer.tokenize_prompts(
                prompts,
                TokenizeParams(max_total_tokens=100),
            )

            assert len(results) == 1
            assert results[0]["prompt_embeds"].dtype == dtype

    def test_prompt_embed_squeeze_batch_dim(self):
        """A leading batch dimension of size 1 is removed from the result."""
        renderer = _build_renderer(MockModelConfig())

        # Test tensor with batch dimension gets squeezed
        tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)

        prompts = renderer.render_completions(
            prompt_embeds=self._create_test_embed_bytes(tensor_input),
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 1
        # Should be squeezed to 2D
        assert results[0]["prompt_embeds"].shape == (10, 768)

    def test_both_prompts_and_embeds(self):
        """Text and embeds together yield the embed prompt first, then the text."""
        renderer = _build_renderer(MockModelConfig())

        text_input = "Hello world"
        tensor_input = torch.randn(5, 256, dtype=torch.float32)

        prompts = renderer.render_completions(
            text_input,
            prompt_embeds=self._create_test_embed_bytes(tensor_input),
        )
        results = renderer.tokenize_prompts(
            prompts,
            TokenizeParams(max_total_tokens=100),
        )

        assert len(results) == 2
        # First should be embed prompt
        assert torch.equal(results[0]["prompt_embeds"], tensor_input)
        # Second should be tokens prompt
        assert "prompt_token_ids" in results[1]
        assert len(results[1]["prompt_token_ids"]) == len(text_input)