"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
"""

5
import unittest
6
import uuid
7
from unittest.mock import Mock
8
9
10
11
12
13
14
15
16
17
18
19
20

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingRequest,
    EmbeddingResponse,
    MultimodalEmbeddingInput,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput


# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
    """Minimal stand-in for the tokenizer manager passed to OpenAIServingEmbedding.

    Provides only what this module configures:

    - ``model_config`` reporting a non-multimodal model.
    - ``server_args`` with cache reporting disabled.
    - A fixed ``model_path``.
    - A ``Mock`` tokenizer with canned ``encode``/``decode`` results.
    - A ``generate_request`` mock that streams one canned embedding result per call.
    """

    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer with fixed encode/decode outputs.
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        async def mock_generate_embedding():
            """Yield a single fake embedding result with a 100-element vector."""
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        # side_effect builds a *fresh* async generator on every call. With
        # return_value, every call would share one generator object, which is
        # exhausted after the first caller iterates it, so later calls would
        # silently yield nothing.
        self.generate_request = Mock(
            side_effect=lambda *args, **kwargs: mock_generate_embedding()
        )


# Stand-in TemplateManager for the embedding tests: no template of any kind
# is configured.
class _MockTemplateManager:
    """Template-manager stub whose three template settings are all ``None``."""

    def __init__(self):
        # Embedding requests normally run without a chat, jinja-format, or
        # completion template, so every setting is left unset.
        for attr_name in (
            "chat_template_name",
            "jinja_template_content_format",
            "completion_template_name",
        ):
            setattr(self, attr_name, None)


class ServingEmbeddingTestCase(unittest.TestCase):
    """Tests for ``OpenAIServingEmbedding._convert_to_internal_request``.

    Each test feeds one shape of ``EmbeddingRequest.input`` through the
    converter:

    - a single string
    - a list of strings
    - a list of token IDs
    - a list of multimodal items

    Each test checks the resulting ``EmbeddingReqInput`` and confirms the
    request object is handed back as the second return value.
    """

    def setUp(self):
        """Build the serving handler with mock managers and sample requests."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.serving_embedding = OpenAIServingEmbedding(
            self.tokenizer_manager, self.template_manager
        )

        self.request = Mock(spec=Request)
        self.request.headers = {}

        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.basic_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.list_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Text and images should be extracted into parallel lists. image_data
        # is checked by position, so text is compared as an ordered list too:
        # membership checks alone would not catch a text/image misalignment.
        self.assertEqual(adapted_request.text, ["Hello", "World"])
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        self.assertEqual(processed_request, self.multimodal_req)


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it,
    # reporting each test individually.
    unittest.main(module="__main__", verbosity=2)