"torchvision/transforms/_functional_tensor.py" did not exist on "b56f17ae1ae8a5d08067c7f7444af21fb3b59ca6"
Unverified Commit 4df5fc21, authored by woodx, committed by GitHub

Feat/refactor embedding server (#7322)

parent a06912ad
@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    serving_embedding: Optional[OpenAIServingEmbedding] = None

 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )

     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret

 @app.post("/v1/score")
...
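For orientation, a minimal client call against the new route could look like the sketch below. The server address and model name are assumptions for illustration, not part of this commit; a malformed body (for example, a missing input field) takes the except branch above and returns the error built by create_error_response.

# Hypothetical smoke test for the new /v1/embeddings route.
import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",  # port assumed
    json={"model": "my-embedding-model", "input": "Hello, how are you?"},
)
resp.raise_for_status()
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension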
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )

         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
...
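Seen in isolation, the narrowed contract means each override converts exactly one request into one internal input. A minimal sketch of a conforming subclass follows; everything except the method name, OpenAIServingBase, and GenerateReqInput is invented for illustration.

# Hypothetical subclass showing the one-request contract (not library code).
class EchoServing(OpenAIServingBase):
    def _convert_to_internal_request(self, request, request_id):
        # One request in, one internal input out - no list wrapping,
        # no per-batch request-id bookkeeping.
        return GenerateReqInput(text=str(request)), request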
@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
-                prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
-                else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
-            else:
-                # Other types (should not happen but handle gracefully)
-                prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
-        else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
-
-        adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
-            **prompt_kwargs,
-        )
-
-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
+                prompt_kwargs = {"text": prompt}
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
+            else:
+                # List of integers (token IDs) or empty list
+                prompt_kwargs = {"input_ids": prompt}
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
+
+        adapted_request = EmbeddingReqInput(
+            **prompt_kwargs,
+        )
+
+        return adapted_request, request

     async def _handle_non_streaming_request(
         self,
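Stripped of the multimodal path, the branching above reduces to a small dispatch rule on the shape of request.input. A self-contained sketch of just that rule, useful when reading the tests further down:

# Standalone mirror of the text/input_ids dispatch above (not library code).
from typing import Any, Dict, List, Union

def to_prompt_kwargs(prompt: Union[str, List[Any]]) -> Dict[str, Any]:
    if isinstance(prompt, str):
        return {"text": prompt}  # single string
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
        return {"text": prompt}  # list of strings
    return {"input_ids": prompt}  # token IDs, empty list, or fallback

assert to_prompt_kwargs("hi") == {"text": "hi"}
assert to_prompt_kwargs(["a", "b"]) == {"text": ["a", "b"]}
assert to_prompt_kwargs([1, 2, 3]) == {"input_ids": [1, 2, 3]}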
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         if not isinstance(ret, list):
             ret = [ret]

-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response

-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
...
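For reference, the response assembled here serializes to the usual OpenAI-style embeddings payload; the field values below are illustrative only, and the exact serialized keys are assumed from the OpenAI API shape rather than taken from this diff.

# Illustrative shape of a serialized EmbeddingResponse (values made up).
response_json = {
    "data": [
        {"object": "embedding", "embedding": [0.0123, -0.0456], "index": 0},
    ],
    "model": "/path/to/model",
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}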
@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)

     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)

     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)

     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )
@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)

     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             }
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
...
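Mirroring the updated assertions, exercising the conversion directly now reads as follows; serving_embedding is assumed to be constructed as in the suite's setUp with a mocked TokenizerManager, and the EmbeddingRequest fields are taken from the protocol import above.

# Sketch based on the tests above; construction details are assumed.
from sglang.srt.openai_api.protocol import EmbeddingRequest

req = EmbeddingRequest(model="test-model", input="Hello, how are you?")
adapted, processed = serving_embedding._convert_to_internal_request(req, "test-id")

assert adapted.text == "Hello, how are you?"
assert adapted.rid is None  # the request id is no longer threaded into rid
assert processed is req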