Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d1112d85
Unverified
Commit
d1112d85
authored
Mar 17, 2025
by
Rin Intachuen
Committed by
GitHub
Mar 16, 2025
Browse files
Add endpoint for file support, purely to speed up processing of input_embeds. (#2797)
parent
48efec7b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
4 deletions
+63
-4
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+24
-0
test/srt/test_input_embeddings.py
test/srt/test_input_embeddings.py
+39
-4
No files found.
python/sglang/srt/entrypoints/http_server.py
View file @
d1112d85
...
@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
...
@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
import
asyncio
import
asyncio
import
dataclasses
import
dataclasses
import
json
import
logging
import
logging
import
multiprocessing
as
multiprocessing
import
multiprocessing
as
multiprocessing
import
os
import
os
...
@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
...
@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return
_create_error_response
(
e
)
return
_create_error_response
(
e
)
@app.api_route("/generate_from_file", methods=["POST"])
async def generate_from_file_request(file: UploadFile, request: Request):
    """Handle a generate request whose ``input_embeds`` arrive as an uploaded file.

    The upload is read in full, decoded as UTF-8 and parsed as JSON; the parsed
    value is passed as ``input_embeds`` to ``GenerateReqInput`` together with a
    fixed set of sampling parameters.

    Args:
        file: Uploaded file containing the JSON-encoded input embeddings.
        request: The incoming HTTP request, forwarded to the generator.

    Returns:
        The first item yielded by ``_global_state.generate_request``, or the
        result of ``_create_error_response`` when a ``ValueError`` is raised.
    """
    content = await file.read()

    try:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses, so parsing inside the try block routes a malformed upload
        # through the same error response as a rejected request instead of
        # letting it escape as an unhandled exception.
        input_embeds = json.loads(content.decode("utf-8"))

        obj = GenerateReqInput(
            input_embeds=input_embeds,
            sampling_params={
                "repetition_penalty": 1.2,
                "temperature": 0.2,
                "max_new_tokens": 512,
            },
        )

        # Only the first yielded item is returned to the caller.
        ret = await _global_state.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        logger.error(f"Error: {e}")
        return _create_error_response(e)
@
app
.
api_route
(
"/encode"
,
methods
=
[
"POST"
,
"PUT"
])
@
app
.
api_route
(
"/encode"
,
methods
=
[
"POST"
,
"PUT"
])
async
def
encode_request
(
obj
:
EmbeddingReqInput
,
request
:
Request
):
async
def
encode_request
(
obj
:
EmbeddingReqInput
,
request
:
Request
):
"""Handle an embedding request."""
"""Handle an embedding request."""
...
...
test/srt/test_input_embeddings.py
View file @
d1112d85
import
json
import
json
import
os
import
tempfile
import
unittest
import
unittest
import
requests
import
requests
...
@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
...
@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
return
embeddings
.
squeeze
().
tolist
()
# Convert tensor to a list for API use
return
embeddings
.
squeeze
().
tolist
()
# Convert tensor to a list for API use
def
send_request
(
self
,
payload
):
def
send_request
(
self
,
payload
):
"""Send a POST request to the
API
and return the response."""
"""Send a POST request to the
/generate endpoint
and return the response."""
response
=
requests
.
post
(
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
payload
,
json
=
payload
,
...
@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
...
@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
"error"
:
f
"Request failed with status
{
response
.
status_code
}
:
{
response
.
text
}
"
"error"
:
f
"Request failed with status
{
response
.
status_code
}
:
{
response
.
text
}
"
}
}
def send_file_request(self, file_path):
    """Upload the file at ``file_path`` to /generate_from_file and return the reply.

    Returns the decoded JSON body when the server answers with status 200;
    otherwise returns a dict with a single ``"error"`` key describing the
    status code and response text.
    """
    endpoint = self.base_url + "/generate_from_file"

    with open(file_path, "rb") as upload:
        response = requests.post(
            endpoint,
            files={"file": upload},
            timeout=30,  # Keep a hung server from stalling the caller
        )

    # Guard clause: anything other than 200 is reported as an error dict.
    if response.status_code != 200:
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }
    return response.json()
def
test_text_based_response
(
self
):
def
test_text_based_response
(
self
):
"""
P
rint API response using text-based input."""
"""
Test and p
rint API response
s
using text-based input."""
for
text
in
self
.
texts
:
for
text
in
self
.
texts
:
payload
=
{
payload
=
{
"model"
:
self
.
model
,
"model"
:
self
.
model
,
...
@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
...
@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
)
)
def
test_embedding_based_response
(
self
):
def
test_embedding_based_response
(
self
):
"""
P
rint API response using input embeddings."""
"""
Test and p
rint API response
s
using input embeddings."""
for
text
in
self
.
texts
:
for
text
in
self
.
texts
:
embeddings
=
self
.
generate_input_embeddings
(
text
)
embeddings
=
self
.
generate_input_embeddings
(
text
)
payload
=
{
payload
=
{
...
@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
...
@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
)
)
def
test_compare_text_vs_embedding
(
self
):
def
test_compare_text_vs_embedding
(
self
):
"""
Print
responses for
both
text-based and embedding-based inputs."""
"""
Test and compare
responses for text-based and embedding-based inputs."""
for
text
in
self
.
texts
:
for
text
in
self
.
texts
:
# Text-based payload
# Text-based payload
text_payload
=
{
text_payload
=
{
...
@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
...
@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
# This is flaky, so we skip this temporarily
# This is flaky, so we skip this temporarily
# self.assertEqual(text_response["text"], embed_response["text"])
# self.assertEqual(text_response["text"], embed_response["text"])
def test_generate_from_file(self):
    """Test the /generate_from_file endpoint using tokenized embeddings.

    For each text in ``self.texts`` the embeddings are written to a temporary
    JSON file, uploaded through ``send_file_request``, and the response is
    printed. The test fails if the helper reports a failed request.
    """
    for text in self.texts:
        embeddings = self.generate_input_embeddings(text)

        # delete=False so the file can be reopened by send_file_request after
        # this handle is closed; it is removed explicitly in the finally below.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        ) as tmp_file:
            json.dump(embeddings, tmp_file)
            tmp_file_path = tmp_file.name

        try:
            response = self.send_file_request(tmp_file_path)
            print(
                f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )
            # send_file_request returns {"error": ...} for any non-200 status;
            # without this check the test would pass even if the endpoint is broken.
            self.assertNotIn("error", response)
        finally:
            # Ensure the temporary file is deleted
            os.remove(tmp_file_path)
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
kill_process_tree
(
cls
.
process
.
pid
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment