# test_serving_embedding.py
"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.

These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""

import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError

from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    MultimodalEmbeddingInput,
    UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput


# Mock TokenizerManager for embedding tests
33
class _MockTokenizerManager:
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method for embeddings
        async def mock_generate_embedding():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        self.generate_request = Mock(return_value=mock_generate_embedding())


61
62
63
64
65
class ServingEmbeddingTestCase(unittest.TestCase):
    """Tests for OpenAIServingEmbedding.

    Covers: conversion of the four supported input forms (single string,
    list of strings, token IDs, multimodal) to EmbeddingReqInput, response
    building with correct usage accounting, and the success / validation /
    generation / internal error paths of handle_request.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)

        # Minimal FastAPI Request stand-in; only headers are read.
        self.request = Mock(spec=Request)
        self.request.headers = {}

        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.basic_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.list_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])

    def test_build_single_embedding_response(self):
        """Test building response for single embedding."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }
        ]

        response = self.serving_embedding._build_embedding_response(ret_data)

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(response.model, "test-model")
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[0].object, "embedding")
        self.assertEqual(response.usage.prompt_tokens, 5)
        self.assertEqual(response.usage.total_tokens, 5)
        self.assertEqual(response.usage.completion_tokens, 0)

    def test_build_multiple_embedding_response(self):
        """Test building response for multiple embeddings."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3],
                "meta_info": {"prompt_tokens": 3},
            },
            {
                "embedding": [0.4, 0.5, 0.6],
                "meta_info": {"prompt_tokens": 4},
            },
        ]

        response = self.serving_embedding._build_embedding_response(ret_data)

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 2)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
        self.assertEqual(response.data[1].index, 1)
        self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
        self.assertEqual(response.usage.total_tokens, 7)

    def test_handle_request_success(self):
        """Test successful embedding request handling."""

        async def run_test():
            # Mock the generate_request to return expected data
            async def mock_generate(*args, **kwargs):
                yield {
                    "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                    "meta_info": {"prompt_tokens": 5},
                }

            # side_effect (not return_value) so each invocation gets a
            # fresh, unconsumed async generator.
            self.serving_embedding.tokenizer_manager.generate_request = Mock(
                side_effect=mock_generate
            )

            response = await self.serving_embedding.handle_request(
                self.basic_req, self.request
            )

            self.assertIsInstance(response, EmbeddingResponse)
            self.assertEqual(len(response.data), 1)
            self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])

        asyncio.run(run_test())

    def test_handle_request_validation_error(self):
        """Test handling request with validation error."""

        async def run_test():
            invalid_request = EmbeddingRequest(model="test-model", input="")

            response = await self.serving_embedding.handle_request(
                invalid_request, self.request
            )

            self.assertIsInstance(response, ORJSONResponse)
            self.assertEqual(response.status_code, 400)

        asyncio.run(run_test())

    def test_handle_request_generation_error(self):
        """Test handling request with generation error."""

        async def run_test():
            # Mock generate_request to raise an error
            async def mock_generate_error():
                raise ValueError("Generation failed")
                yield  # This won't be reached but needed for async generator

            self.serving_embedding.tokenizer_manager.generate_request = Mock(
                return_value=mock_generate_error()
            )

            response = await self.serving_embedding.handle_request(
                self.basic_req, self.request
            )

            self.assertIsInstance(response, ORJSONResponse)
            self.assertEqual(response.status_code, 400)

        asyncio.run(run_test())

    def test_handle_request_internal_error(self):
        """Test handling request with internal server error."""

        async def run_test():
            # Mock _convert_to_internal_request to raise an exception
            with patch.object(
                self.serving_embedding,
                "_convert_to_internal_request",
                side_effect=Exception("Internal error"),
            ):
                response = await self.serving_embedding.handle_request(
                    self.basic_req, self.request
                )

                self.assertIsInstance(response, ORJSONResponse)
                self.assertEqual(response.status_code, 500)

        asyncio.run(run_test())
269
270
271
272


# Allow running this test module directly; verbosity=2 prints one line per test.
if __name__ == "__main__":
    unittest.main(verbosity=2)