"""
Tests for the refactored completions serving handler
"""

from unittest.mock import AsyncMock, Mock, patch

import pytest

from sglang.srt.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionStreamResponse,
    ErrorResponse,
)
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager


@pytest.fixture
def mock_tokenizer_manager():
    """Build a Mock standing in for TokenizerManager with every attribute the handler reads."""
    manager = Mock(spec=TokenizerManager)

    # Tokenizer stub: canned encode/decode results plus a BOS token id.
    tokenizer = Mock()
    tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
    tokenizer.decode = Mock(return_value="decoded text")
    tokenizer.bos_token_id = 1
    manager.tokenizer = tokenizer

    # Text-only model configuration.
    manager.model_config = Mock()
    manager.model_config.is_multimodal = False

    # Server arguments with cache reporting turned off.
    manager.server_args = Mock()
    manager.server_args.enable_cache_report = False

    # Async generation entry point and a no-op abort-task factory.
    manager.generate_request = AsyncMock()
    manager.create_abort_task = Mock(return_value=None)

    return manager


@pytest.fixture
def serving_completion(mock_tokenizer_manager):
    """Instantiate the completion handler under test, wired to the mocked manager."""
    handler = OpenAIServingCompletion(mock_tokenizer_manager)
    return handler


class TestPromptHandling:
    """Test different prompt types and formats from adapter.py"""

    def test_single_string_prompt(self, serving_completion):
        """A plain string prompt should pass through as adapted_request.text."""
        req = CompletionRequest(
            model="test-model", prompt="Hello world", max_tokens=100
        )

        adapted, _ = serving_completion._convert_to_internal_request(
            [req], ["test-id"]
        )

        assert adapted.text == "Hello world"

    def test_single_token_ids_prompt(self, serving_completion):
        """A list of token ids should land in adapted_request.input_ids unchanged."""
        req = CompletionRequest(
            model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
        )

        adapted, _ = serving_completion._convert_to_internal_request(
            [req], ["test-id"]
        )

        assert adapted.input_ids == [1, 2, 3, 4]

    def test_completion_template_handling(self, serving_completion):
        """When a completion template is defined, the generated prompt replaces the raw one."""
        req = CompletionRequest(
            model="test-model",
            prompt="def hello():",
            suffix="return 'world'",
            max_tokens=100,
        )

        with patch(
            "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
            return_value=True,
        ), patch(
            "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
            return_value="processed_prompt",
        ):
            adapted, _ = serving_completion._convert_to_internal_request(
                [req], ["test-id"]
            )

            assert adapted.text == "processed_prompt"


class TestEchoHandling:
    """Test echo functionality from adapter.py"""

    def test_echo_with_string_prompt_streaming(self, serving_completion):
        """Echo of a single string prompt returns that string."""
        req = CompletionRequest(
            model="test-model", prompt="Hello", max_tokens=100, echo=True
        )

        assert serving_completion._get_echo_text(req, 0) == "Hello"

    def test_echo_with_list_of_strings_streaming(self, serving_completion):
        """Echo of a list of string prompts is indexed per-prompt."""
        req = CompletionRequest(
            model="test-model",
            prompt=["Hello", "World"],
            max_tokens=100,
            echo=True,
            n=1,
        )

        assert serving_completion._get_echo_text(req, 0) == "Hello"
        assert serving_completion._get_echo_text(req, 1) == "World"

    def test_echo_with_token_ids_streaming(self, serving_completion):
        """Echo of a token-id prompt goes through tokenizer.decode."""
        req = CompletionRequest(
            model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
        )

        serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
            "decoded_prompt"
        )
        assert serving_completion._get_echo_text(req, 0) == "decoded_prompt"

    def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
        """Echo of a batch of token-id prompts decodes the indexed prompt."""
        req = CompletionRequest(
            model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
        )

        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        assert serving_completion._get_echo_text(req, 0) == "decoded"

    def test_prepare_echo_prompts_non_streaming(self, serving_completion):
        """_prepare_echo_prompts normalizes every prompt shape into a list of strings."""
        # Single string prompt.
        req = CompletionRequest(model="test-model", prompt="Hello", echo=True)
        assert serving_completion._prepare_echo_prompts(req) == ["Hello"]

        # List of string prompts.
        req = CompletionRequest(
            model="test-model", prompt=["Hello", "World"], echo=True
        )
        assert serving_completion._prepare_echo_prompts(req) == ["Hello", "World"]

        # Token-id prompt: decoded via the tokenizer.
        req = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
        assert serving_completion._prepare_echo_prompts(req) == ["decoded"]