test_renderer.py 5.85 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock

import pytest

from vllm.entrypoints.renderer import CompletionRenderer


@dataclass
class MockModelConfig:
    """Lightweight stand-in for the model config given to the renderer.

    Only the two attributes declared below exist on this mock; both have
    defaults so the config can be created without arguments.
    """

    # Model length limit; the tests in this file use prompts sized around
    # this value (100).
    max_model_len: int = 100
    # No encoder configuration unless a test supplies one.
    encoder_config: Optional[dict] = None


class MockTokenizerResult:
    """Object returned by the mocked tokenizer in these tests.

    It carries a single ``input_ids`` attribute holding whatever value was
    passed to the constructor.
    """

    def __init__(self, input_ids):
        # Stored as-is: no copy and no validation.
        self.input_ids = input_ids


@pytest.fixture
def mock_model_config():
    """Provide a ``MockModelConfig`` built with its default values."""
    default_config = MockModelConfig()
    return default_config


@pytest.fixture
def mock_tokenizer():
    """Provide a ``MagicMock`` used as the renderer's tokenizer."""
    return MagicMock()


@pytest.fixture
def mock_async_tokenizer():
    """Provide an ``AsyncMock`` that tests register as the async tokenizer."""
    return AsyncMock()


@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
    """Build a ``CompletionRenderer`` from the mock config and tokenizer.

    The async tokenizer pool starts out as an empty dict; tests that render
    text prompts add their own entry to it.
    """
    renderer_kwargs = {
        "model_config": mock_model_config,
        "tokenizer": mock_tokenizer,
        "async_tokenizer_pool": {},
    }
    return CompletionRenderer(**renderer_kwargs)


class TestRenderPrompt:
    """Category A: basic functionality tests for ``render_prompt``.

    Exercises token and text prompts (single and batched), truncation
    handling, the length check and the missing-tokenizer error.
    """

    @staticmethod
    def _install_async_tokenizer(renderer, async_tokenizer, token_ids):
        """Make ``async_tokenizer`` yield ``token_ids`` and register it.

        The mock is placed in the renderer's async tokenizer pool, keyed by
        the renderer's own tokenizer.
        """
        async_tokenizer.return_value = MockTokenizerResult(token_ids)
        renderer.async_tokenizer_pool[renderer.tokenizer] = async_tokenizer

    @pytest.mark.asyncio
    async def test_token_input(self, renderer):
        """A flat list of token ids comes back as one unchanged prompt."""
        token_ids = [101, 7592, 2088]

        rendered = await renderer.render_prompt(prompt_or_prompts=token_ids,
                                                max_length=100)

        assert len(rendered) == 1
        assert rendered[0]["prompt_token_ids"] == token_ids

    @pytest.mark.asyncio
    async def test_token_list_input(self, renderer):
        """A list of token-id lists yields one prompt per inner list."""
        expected = ([101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567])
        # Pass copies so the expectation is independent of the input objects.
        batch = [list(ids) for ids in expected]

        rendered = await renderer.render_prompt(prompt_or_prompts=batch,
                                                max_length=100)

        assert len(rendered) == 3
        for position, want in enumerate(expected):
            assert rendered[position]["prompt_token_ids"] == want

    @pytest.mark.asyncio
    async def test_text_input(self, renderer, mock_async_tokenizer):
        """A single string is tokenized with exactly one tokenizer call."""
        self._install_async_tokenizer(renderer, mock_async_tokenizer,
                                      [101, 7592, 2088])

        rendered = await renderer.render_prompt(
            prompt_or_prompts="Hello world", max_length=100)

        assert len(rendered) == 1
        assert rendered[0]["prompt_token_ids"] == [101, 7592, 2088]
        mock_async_tokenizer.assert_called_once()

    @pytest.mark.asyncio
    async def test_text_list_input(self, renderer, mock_async_tokenizer):
        """Each string in a list is tokenized with its own call."""
        self._install_async_tokenizer(renderer, mock_async_tokenizer,
                                      [101, 7592, 2088])
        texts = ["Hello world", "How are you?", "Good morning"]

        rendered = await renderer.render_prompt(prompt_or_prompts=texts,
                                                max_length=100)

        assert len(rendered) == 3
        # The mock returns the same ids for every prompt.
        for item in rendered:
            assert item["prompt_token_ids"] == [101, 7592, 2088]
        assert mock_async_tokenizer.call_count == 3

    @pytest.mark.asyncio
    async def test_no_truncation(self, renderer, mock_async_tokenizer):
        """Without ``truncate_prompt_tokens`` no truncation is requested."""
        self._install_async_tokenizer(renderer, mock_async_tokenizer,
                                      [101, 7592, 2088])

        rendered = await renderer.render_prompt(
            prompt_or_prompts="Hello world", max_length=100)

        assert len(rendered) == 1
        tokenizer_kwargs = mock_async_tokenizer.call_args.kwargs
        # Either the flag is absent or it is explicitly False.
        assert tokenizer_kwargs.get("truncation", False) is False

    @pytest.mark.asyncio
    async def test_truncation_positive(self, renderer, mock_async_tokenizer):
        """A positive ``truncate_prompt_tokens`` reaches the tokenizer."""
        # The mocked result plays the role of the already-truncated output.
        self._install_async_tokenizer(renderer, mock_async_tokenizer,
                                      [101, 7592, 2088])

        rendered = await renderer.render_prompt(
            prompt_or_prompts="Hello world",
            max_length=100,
            truncate_prompt_tokens=50)

        assert len(rendered) == 1
        tokenizer_kwargs = mock_async_tokenizer.call_args.kwargs
        assert tokenizer_kwargs["truncation"] is True
        assert tokenizer_kwargs["max_length"] == 50

    @pytest.mark.asyncio
    async def test_token_truncation_last_elements(self, renderer):
        """Truncating a token prompt keeps the last N token ids."""
        ten_tokens = list(range(100, 110))  # 100..109, 10 tokens

        rendered = await renderer.render_prompt(prompt_or_prompts=ten_tokens,
                                                max_length=100,
                                                truncate_prompt_tokens=5)

        assert len(rendered) == 1
        # Only the trailing five ids are expected to remain.
        assert rendered[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]

    @pytest.mark.asyncio
    async def test_max_length_exceeded(self, renderer):
        """A 150-token prompt with ``max_length=100`` raises ValueError."""
        oversized_prompt = list(range(150))

        with pytest.raises(ValueError, match="maximum context length"):
            await renderer.render_prompt(prompt_or_prompts=oversized_prompt,
                                         max_length=100)

    @pytest.mark.asyncio
    async def test_no_tokenizer_for_text(self, mock_model_config):
        """Rendering text with ``tokenizer=None`` raises ValueError."""
        bare_renderer = CompletionRenderer(model_config=mock_model_config,
                                           tokenizer=None,
                                           async_tokenizer_pool={})

        with pytest.raises(ValueError, match="No tokenizer available"):
            await bare_renderer.render_prompt(prompt_or_prompts="Hello world",
                                              max_length=100)