"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.

These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""

import asyncio
import json
import time
import unittest
import uuid
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

from fastapi import Request
from fastapi.responses import ORJSONResponse
from pydantic_core import ValidationError

from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingObject,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    MultimodalEmbeddingInput,
    UsageInfo,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput


# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method for embeddings
        async def mock_generate_embedding():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        self.generate_request = Mock(return_value=mock_generate_embedding())
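        # Note: a single async-generator instance backs this mock, so it can serve
        # only one generate_request call; tests that need specific results
        # re-assign generate_request themselves.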


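# IsolatedAsyncioTestCase is used so the async def test_* methods below are
# actually awaited when the suite runs.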
class ServingEmbeddingTestCase(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.serving_embedding = OpenAIServingEmbedding(self.tokenizer_manager)

        self.request = Mock(spec=Request)
        self.request.headers = {}

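        # Request fixtures for the supported embedding input formats: a single
        # string, a list of strings, multimodal text+image inputs, and token ids.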
        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                self.basic_req, "test-id"
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(adapted_request.rid, None)
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                self.list_req, "test-id"
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(adapted_request.rid, None)
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                self.token_ids_req, "test-id"
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(adapted_request.rid, None)
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(
                self.multimodal_req, "test-id"
            )
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        self.assertEqual(adapted_request.rid, None)

    def test_build_single_embedding_response(self):
        """Test building response for single embedding."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }
        ]

        response = self.serving_embedding._build_embedding_response(ret_data)

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(response.model, "test-model")
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[0].object, "embedding")
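        # Embeddings produce no completion tokens, so total_tokens mirrors prompt_tokens.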
        self.assertEqual(response.usage.prompt_tokens, 5)
        self.assertEqual(response.usage.total_tokens, 5)
        self.assertEqual(response.usage.completion_tokens, 0)

    def test_build_multiple_embedding_response(self):
        """Test building response for multiple embeddings."""
        ret_data = [
            {
                "embedding": [0.1, 0.2, 0.3],
                "meta_info": {"prompt_tokens": 3},
            },
            {
                "embedding": [0.4, 0.5, 0.6],
                "meta_info": {"prompt_tokens": 4},
            },
        ]

        response = self.serving_embedding._build_embedding_response(ret_data)

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 2)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
        self.assertEqual(response.data[0].index, 0)
        self.assertEqual(response.data[1].embedding, [0.4, 0.5, 0.6])
        self.assertEqual(response.data[1].index, 1)
        self.assertEqual(response.usage.prompt_tokens, 7)  # 3 + 4
        self.assertEqual(response.usage.total_tokens, 7)

    async def test_handle_request_success(self):
        """Test successful embedding request handling."""

        # Mock the generate_request to return expected data
        async def mock_generate():
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                "meta_info": {"prompt_tokens": 5},
            }

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, EmbeddingResponse)
        self.assertEqual(len(response.data), 1)
        self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3, 0.4, 0.5])

    async def test_handle_request_validation_error(self):
        """Test handling request with validation error."""
        invalid_request = EmbeddingRequest(model="test-model", input="")

        response = await self.serving_embedding.handle_request(
            invalid_request, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_generation_error(self):
        """Test handling request with generation error."""

        # Mock generate_request to raise an error
        async def mock_generate_error():
            raise ValueError("Generation failed")
            yield  # unreachable, but required to make this an async generator

        self.serving_embedding.tokenizer_manager.generate_request = Mock(
            return_value=mock_generate_error()
        )

        response = await self.serving_embedding.handle_request(
            self.basic_req, self.request
        )

        self.assertIsInstance(response, ORJSONResponse)
        self.assertEqual(response.status_code, 400)

    async def test_handle_request_internal_error(self):
        """Test handling request with internal server error."""
        # Mock _convert_to_internal_request to raise an exception
        with patch.object(
            self.serving_embedding,
            "_convert_to_internal_request",
            side_effect=Exception("Internal error"),
        ):
            response = await self.serving_embedding.handle_request(
                self.basic_req, self.request
            )

            self.assertIsInstance(response, ORJSONResponse)
            self.assertEqual(response.status_code, 500)


if __name__ == "__main__":
    unittest.main(verbosity=2)