"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.

These tests ensure that the embedding serving implementation maintains compatibility
with the original adapter.py functionality and follows OpenAI API specifications.
"""

import unittest
import uuid
from unittest.mock import Mock

from fastapi import Request

from sglang.srt.entrypoints.openai.protocol import (
    EmbeddingRequest,
    EmbeddingResponse,
    MultimodalEmbeddingInput,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput


# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
    """Minimal stand-in for TokenizerManager used by the embedding tests.

    Provides a stub model config, server args, tokenizer and a
    ``generate_request`` mock that streams one canned embedding result.
    """

    def __init__(self):
        self.model_config = Mock()
        self.model_config.is_multimodal = False
        self.server_args = Mock()
        self.server_args.enable_cache_report = False
        self.model_path = "test-model"

        # Mock tokenizer
        self.tokenizer = Mock()
        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
        self.tokenizer.decode = Mock(return_value="Test embedding input")
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1

        # Mock generate_request method for embeddings: an async generator that
        # yields a single result dict and then finishes.
        async def mock_generate_embedding(*args, **kwargs):
            yield {
                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
                "meta_info": {
                    "id": f"embd-{uuid.uuid4()}",
                    "prompt_tokens": 5,
                },
            }

        # Use side_effect (not return_value) so every call builds a fresh async
        # generator; a single shared generator would be exhausted after the first
        # request and silently yield nothing for any later one.
        self.generate_request = Mock(side_effect=mock_generate_embedding)


# Mock TemplateManager for embedding tests
class _MockTemplateManager:
    """Stand-in TemplateManager with no chat/completion templates configured."""

    def __init__(self):
        self.chat_template_name = None  # None for embeddings usually
        self.jinja_template_content_format = None
        self.completion_template_name = None


class ServingEmbeddingTestCase(unittest.TestCase):
    """Tests for OpenAIServingEmbedding._convert_to_internal_request."""

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.serving_embedding = OpenAIServingEmbedding(
            self.tokenizer_manager, self.template_manager
        )

        self.request = Mock(spec=Request)
        self.request.headers = {}

        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.basic_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.list_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
        )

        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately, in input order, so that each
        # image stays paired with the text it was submitted alongside.
        self.assertEqual(list(adapted_request.text), ["Hello", "World"])
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        # Consistent with the other conversion tests: the original request is
        # handed back unchanged.
        self.assertEqual(processed_request, self.multimodal_req)


if __name__ == "__main__":
    unittest.main(verbosity=2)