# test_serving_embedding.py
"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.

These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""

import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError

from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    MultimodalEmbeddingInput,
    UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput


# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
    """Lightweight stand-in for the tokenizer manager used by the embedding tests.

    Only the attributes assigned in ``__init__`` are provided: a mocked model
    config, server args, tokenizer, and a ``generate_request`` callable that
    produces a canned embedding result.
    """

    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method for embeddings. The generator accepts
        # (and ignores) whatever arguments the caller passes to
        # ``generate_request``.
        async def mock_generate_embedding(*_args, **_kwargs):
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        # Use side_effect rather than return_value: return_value would hand the
        # SAME async generator object to every caller, so any call after the
        # first would receive an already-exhausted generator. side_effect
        # invokes the generator function per call, yielding a fresh generator.
        self.generate_request = Mock(side_effect=mock_generate_embedding)


# Test cases for OpenAIServingEmbedding
class ServingEmbeddingTestCase(unittest.IsolatedAsyncioTestCase):
    """Tests for OpenAIServingEmbedding request conversion, response building
    and request handling.

    The class derives from ``IsolatedAsyncioTestCase`` because several tests
    are coroutines. Under a plain ``unittest.TestCase`` an ``async def`` test
    is called but never awaited, so its assertions would never execute and the
    test would be reported as passing regardless of behavior.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)

        # Minimal FastAPI request stand-in; only ``headers`` is populated.
        self.request = Mock(spec=Request)
        self.request.headers = {}

        # One request fixture per supported input shape.
        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.basic_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.list_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.token_ids_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                [self.multimodal_req], ["test-id"]
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        self.assertEqual(adapted_request.rid, "test-id")

    def test_build_single_embedding_response(self):
        """Test building response for single embedding."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }
        ]

        response = self.serving_embedding._build_embedding_response(
            ret_data, "test-model"
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(response.model, "test-model")
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[0].object, "embedding")
        self.assertEqual(response.usage.prompt_tokens, 5)
        self.assertEqual(response.usage.total_tokens, 5)
        self.assertEqual(response.usage.completion_tokens, 0)

    def test_build_multiple_embedding_response(self):
        """Test building response for multiple embeddings."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3],
                "meta_info": {"prompt_tokens": 3},
            },
            {
                "embedding": [0.4, 0.5, 0.6],
                "meta_info": {"prompt_tokens": 4},
            },
        ]

        response = self.serving_embedding._build_embedding_response(
            ret_data, "test-model"
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 2)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
        self.assertEqual(response.data[1].index, 1)
        self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
        self.assertEqual(response.usage.total_tokens, 7)

    async def test_handle_request_success(self):
        """Test successful embedding request handling."""

        # Mock the generate_request to return expected data
        async def mock_generate():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])

    async def test_handle_request_validation_error(self):
        """Test handling request with validation error."""
        invalid_request = EmbeddingRequest(model="test-model", input="")

        response = await self.serving_embedding.handle_request(
            invalid_request, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_generation_error(self):
        """Test handling request with generation error."""

        # Mock generate_request to raise an error
        async def mock_generate_error():
            raise ValueError("Generation failed")
            yield  # This won't be reached but needed for async generator

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate_error()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_internal_error(self):
        """Test handling request with internal server error."""
        # Mock _convert_to_internal_request to raise an exception
        with patch.object(
            self.serving_embedding,
            "_convert_to_internal_request",
            side_effect=Exception("Internal error"),
        ):
            response = await self.serving_embedding.handle_request(
                self.basic_req, self.request
            )

            self.assertIsInstance(response, ORJSONResponse)
            self.assertEqual(response.status_code, 500)


if __name__ == "__main__":
    unittest.main(verbosity=2)