change / sglang / Commits

Commit 4df5fc21 (unverified)
Authored Jun 20, 2025 by woodx; committed by GitHub on Jun 19, 2025

Feat/refactor embedding server (#7322)

Parent: a06912ad

Showing 4 changed files with 76 additions and 120 deletions (+76, -120)
python/sglang/srt/entrypoints/openai/api_server.py         +16   -2
python/sglang/srt/entrypoints/openai/serving_base.py         +3   -3
python/sglang/srt/entrypoints/openai/serving_embedding.py   +47 -101
test/srt/openai/test_serving_embedding.py                   +10  -14
python/sglang/srt/entrypoints/openai/api_server.py

@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    embedding_server: Optional[OpenAIServingEmbedding] = None
 
 
 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )
 
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret
 
 
 @app.post("/v1/score")
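With this wiring in place, /v1/embeddings parses the JSON body into an EmbeddingRequest and hands it to the OpenAIServingEmbedding instance stored on app.state; validation failures are returned through create_error_response rather than raised. A minimal client-side sketch of exercising the endpoint, assuming a server is already running locally (the URL, port, and model name are illustrative, not part of this commit):

import requests

# Hypothetical local deployment; adjust the URL/port and model name.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "my-embedding-model", "input": "Hello, how are you?"},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
# The response follows the OpenAI embedding schema: a list of embedding
# objects under "data" plus a "usage" block with token counts.
print(len(body["data"][0]["embedding"]))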
python/sglang/srt/entrypoints/openai/serving_base.py

@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )
         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
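The abstract hook now takes one request and one request id instead of the parallel all_requests/request_ids lists, matching the single-request call site above. A sketch of what a conforming subclass override could look like under the new contract (the subclass name, the `input` field on the request, and the GenerateReqInput field names are illustrative assumptions, not code from this commit; OpenAIServingBase, OpenAIServingRequest, and GenerateReqInput are the names used in serving_base.py above):

class MyServingHandler(OpenAIServingBase):
    def _convert_to_internal_request(
        self,
        request: OpenAIServingRequest,
        request_id: str,
    ) -> tuple[GenerateReqInput, OpenAIServingRequest]:
        # One OpenAI-style request maps directly onto one internal request;
        # no list plumbing or single-vs-batch branching is needed anymore.
        adapted = GenerateReqInput(text=request.input, rid=request_id)
        return adapted, request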
python/sglang/srt/entrypoints/openai/serving_embedding.py

@@ -71,15 +71,11 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
-                prompt_kwargs = {"text": prompt}
+        prompt = request.input
+
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
@@ -87,9 +83,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
-            if len(prompt) > 0 and isinstance(prompt[0], str):
-                # List of strings
-                prompt_kwargs = {"text": prompt}
-            elif len(prompt) > 0 and isinstance(
-                prompt[0], MultimodalEmbeddingInput
-            ):
+        if len(prompt) > 0 and isinstance(prompt[0], str):
+            # List of strings
+            prompt_kwargs = {"text": prompt}
+        elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
             # Handle multimodal embedding inputs
             texts = []
             images = []
@@ -105,9 +99,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
-                self.tokenizer_manager, "chat_template_name", None
-            )
-            if chat_template_name is not None:
-                convs = generate_embedding_convs(
-                    texts, images, chat_template_name
-                )
-                for conv in convs:
-                    generate_prompts.append(conv.get_prompt())
-            else:
+            self.tokenizer_manager, "chat_template_name", None
+        )
+        if chat_template_name is not None:
+            convs = generate_embedding_convs(texts, images, chat_template_name)
+            for conv in convs:
+                generate_prompts.append(conv.get_prompt())
+        else:
@@ -129,53 +121,11 @@ class OpenAIServingEmbedding(OpenAIServingBase):
-            else:
-                # Other types (should not happen but handle gracefully)
-                prompt_kwargs = {"input_ids": prompt}
-
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
-        else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
-
-        adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
-            **prompt_kwargs,
-        )
-
-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
+
+        adapted_request = EmbeddingReqInput(
+            **prompt_kwargs,
+        )
+
+        return adapted_request, request
 
     async def _handle_non_streaming_request(
         self,
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         if not isinstance(ret, list):
             ret = [ret]
 
-        response = self._build_embedding_response(ret)
+        response = self._build_embedding_response(
+            ret, self.tokenizer_manager.model_path
+        )
         return response
 
-    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
+    def _build_embedding_response(
+        self, ret: List[Dict[str, Any]], model_path: str
+    ) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         return EmbeddingResponse(
             data=embedding_objects,
-            model=self.tokenizer_manager.model_path,
+            model=model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
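Note the behavioral consequence visible in the tests below: the rewritten converter builds EmbeddingReqInput from **prompt_kwargs alone and no longer forwards the request id as rid, so adapted_request.rid now comes back as None. A minimal sketch of the new conversion path, assuming the constructor only needs the tokenizer_manager reference (as the lifespan wiring above suggests) and that EmbeddingRequest accepts model and input fields:

from unittest.mock import MagicMock

from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.openai_api.protocol import EmbeddingRequest

# A mocked tokenizer_manager stands in for the real one in this sketch.
serving = OpenAIServingEmbedding(tokenizer_manager=MagicMock())
req = EmbeddingRequest(model="my-embedding-model", input="Hello, how are you?")

adapted, processed = serving._convert_to_internal_request(req, "req-1")
assert adapted.text == "Hello, how are you?"
assert adapted.rid is None  # the request id is no longer copied onto rid
assert processed is req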
test/srt/openai/test_serving_embedding.py

@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)
 
     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)
 
     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)
 
     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )
@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
 
     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             }
         ]
 
-        response = self.serving_embedding._build_embedding_response(ret_data)
+        response = self.serving_embedding._build_embedding_response(
+            ret_data, "test-model"
+        )
 
         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]
 
-        response = self.serving_embedding._build_embedding_response(ret_data)
+        response = self.serving_embedding._build_embedding_response(
+            ret_data, "test-model"
+        )
 
         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
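To run just this updated module, standard unittest discovery from the repository root is enough (a sketch; the discovery path matches the file location shown above):

import unittest

# Discover and run only the embedding serving tests touched by this commit.
suite = unittest.defaultTestLoader.discover(
    "test/srt/openai", pattern="test_serving_embedding.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)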