change / sglang / Commits / 4df5fc21

Unverified commit 4df5fc21, authored Jun 20, 2025 by woodx, committed by GitHub on Jun 19, 2025
Parent: a06912ad

Feat/refactor embedding server (#7322)

Showing 4 changed files with 76 additions and 120 deletions (+76, -120)
python/sglang/srt/entrypoints/openai/api_server.py        (+16, -2)
python/sglang/srt/entrypoints/openai/serving_base.py       (+3, -3)
python/sglang/srt/entrypoints/openai/serving_embedding.py  (+47, -101)
test/srt/openai/test_serving_embedding.py                  (+10, -14)
python/sglang/srt/entrypoints/openai/api_server.py

@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    embedding_server: Optional[OpenAIServingEmbedding] = None


 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )

     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret


 @app.post("/v1/score")
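For reference, a client-side sketch of exercising the refactored /v1/embeddings route. This is not part of the commit: the base URL and model name are placeholders, and the payload simply follows the OpenAI embeddings schema that EmbeddingRequest is expected to parse.

# Hedged usage sketch (not from the commit): POST an OpenAI-style embeddings
# payload to a locally running sglang server.
import requests  # third-party HTTP client

BASE_URL = "http://localhost:30000"  # placeholder address for a local deployment

resp = requests.post(
    f"{BASE_URL}/v1/embeddings",
    json={
        "model": "my-embedding-model",   # placeholder model name
        "input": "Hello, how are you?",  # may also be a list of strings or token IDs
    },
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
# On success the handler returns EmbeddingResponse-shaped JSON:
print(len(body["data"][0]["embedding"]), body["usage"]["prompt_tokens"])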
python/sglang/srt/entrypoints/openai/serving_base.py

@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
             # Convert to internal format
             adapted_request, processed_request = self._convert_to_internal_request(
-                [request], [self._generate_request_id_base(request)]
+                request, self._generate_request_id_base(request)
             )

             # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
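To make the new contract concrete, here is a minimal sketch of a subclass under the single-request signature. The types are stand-ins (a dict request and a tiny dataclass instead of sglang's OpenAIServingRequest and GenerateReqInput); only the call shape mirrors the commit.

# Hedged sketch of the per-request contract; stand-in types, not sglang's.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass
class FakeGenerateReqInput:  # stand-in for GenerateReqInput
    rid: str
    text: str

class ServingBaseSketch(ABC):
    @abstractmethod
    def _convert_to_internal_request(
        self, request: Any, request_id: str
    ) -> Tuple[FakeGenerateReqInput, Any]:
        """Convert exactly one OpenAI-style request to the internal format."""

class CompletionSketch(ServingBaseSketch):
    def _convert_to_internal_request(self, request, request_id):
        # One request in, one adapted request out, plus the original request back.
        return FakeGenerateReqInput(rid=request_id, text=request["prompt"]), request

adapted, original = CompletionSketch()._convert_to_internal_request(
    {"prompt": "hello"}, "req-123"
)
assert adapted.rid == "req-123" and original["prompt"] == "hello"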
python/sglang/srt/entrypoints/openai/serving_embedding.py

@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
-                prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
-                else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
-            else:
-                # Other types (should not happen but handle gracefully)
-                prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
-        else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
+                prompt_kwargs = {"text": prompt}
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
+            else:
+                # List of integers (token IDs) or empty list
+                prompt_kwargs = {"input_ids": prompt}
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}

         adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
             **prompt_kwargs,
         )

-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        return adapted_request, request

     async def _handle_non_streaming_request(
         self,
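The branching above reduces to a small dispatch rule on the type of request.input. A hedged, self-contained restatement (plain dicts instead of EmbeddingReqInput, multimodal branch omitted), not part of the commit:

# Hedged restatement of the converter's dispatch rule; illustrative only.
from typing import Any, Dict, List, Union

def to_prompt_kwargs(prompt: Union[str, List[Any]]) -> Dict[str, Any]:
    if isinstance(prompt, str):
        return {"text": prompt}  # single string input
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
        return {"text": prompt}  # list of strings
    return {"input_ids": prompt}  # token IDs, empty list, or other types

assert to_prompt_kwargs("hi") == {"text": "hi"}
assert to_prompt_kwargs(["a", "b"]) == {"text": ["a", "b"]}
assert to_prompt_kwargs([1, 2, 3]) == {"input_ids": [1, 2, 3]}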
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         if not isinstance(ret, list):
             ret = [ret]

-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response

-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
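The response-assembly change is easier to see outside the diff: the model name now comes from self.tokenizer_manager.model_path instead of an explicit parameter. A hedged sketch of the assembly, with plain dicts standing in for the EmbeddingResponse and UsageInfo models (the per-item field names are illustrative, not sglang's exact internal shape):

# Hedged sketch of response assembly; dicts stand in for pydantic models.
from typing import Any, Dict, List

def build_embedding_response(ret: List[Dict[str, Any]], model_path: str) -> Dict[str, Any]:
    data: List[Dict[str, Any]] = []
    prompt_tokens = 0
    for idx, item in enumerate(ret):
        data.append({"object": "embedding", "embedding": item["embedding"], "index": idx})
        prompt_tokens += item.get("prompt_tokens", 0)  # illustrative field name
    return {
        "object": "list",
        "data": data,
        "model": model_path,  # in the commit: self.tokenizer_manager.model_path
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }

print(build_embedding_response([{"embedding": [0.1, 0.2], "prompt_tokens": 5}], "test-model"))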
test/srt/openai/test_serving_embedding.py

@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)

     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)

     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)

     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )
@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)

     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             }
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
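To run the updated suite locally, standard unittest discovery should suffice; a minimal sketch, assuming the repository root is the working directory and sglang's imports resolve:

# Hedged sketch: drive the updated tests via unittest discovery.
import unittest

suite = unittest.defaultTestLoader.discover(
    "test/srt/openai", pattern="test_serving_embedding.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)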