change / sglang · Commits

Commit b16e856f (unverified)
Add openai embedding API (#997)

Authored Aug 09, 2024 by Ying Sheng; committed by GitHub on Aug 09, 2024
Parent: 05c50a82

Showing 8 changed files with 135 additions and 19 deletions (+135 / -19)
python/sglang/srt/managers/io_struct.py           +6   -4
python/sglang/srt/managers/tokenizer_manager.py   +1   -0
python/sglang/srt/managers/tp_worker.py           +2   -0
python/sglang/srt/openai_api/adapter.py           +21  -11
python/sglang/srt/openai_api/protocol.py          +9   -3
python/sglang/srt/server.py                       +8   -1
test/srt/run_suite.py                             +1   -0
test/srt/test_embedding_openai_server.py          +87  -0
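Taken together, the commit threads an OpenAI-compatible /v1/embeddings endpoint through the whole stack: request dataclass, tokenizer manager, TP worker, OpenAI adapter, protocol models, HTTP server, and a new test. As a sketch of the client-side usage this enables (mirroring the new test file below; the base URL and API key are whatever the server was launched with):

    import openai

    # Values mirror test_embedding_openai_server.py; adjust to your deployment.
    client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:8157/v1")
    response = client.embeddings.create(
        model="intfloat/e5-mistral-7b-instruct",
        input="The capital of France is",  # a list of token ids also works
    )
    print(response.data[0].embedding[:4], response.usage.prompt_tokens)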
python/sglang/srt/managers/io_struct.py

@@ -194,7 +194,8 @@ class EmbeddingReqInput:
         if is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            self.sampling_params = {"max_new_tokens": 0}
+            if self.sampling_params is None:
+                self.sampling_params = {"max_new_tokens": 1}
         else:
             # support select operation
             self.batch_size = (
@@ -205,9 +206,10 @@ class EmbeddingReqInput:
         else:
             if not isinstance(self.rid, list):
                 raise ValueError("The rid should be a list.")
-            self.sampling_params = [
-                {"max_new_tokens": 0} for _ in range(self.batch_size)
-            ]
+            if self.sampling_params is None:
+                self.sampling_params = [
+                    {"max_new_tokens": 1} for _ in range(self.batch_size)
+                ]


 @dataclass
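The effect of this hunk: a caller-supplied sampling_params is no longer overwritten, and the dummy decode budget becomes 1 token instead of 0, matching the tp_worker change below that appends one dummy output token per embedding request. A minimal sketch of the new behavior (the entry point and constructor fields beyond this hunk are assumptions based on the surrounding dataclass):

    from sglang.srt.managers.io_struct import EmbeddingReqInput

    # The default is filled in only when the caller left sampling_params unset.
    req = EmbeddingReqInput(text="The capital of France is")
    req.post_init()  # assuming the hunk sits in post_init(), as for GenerateReqInput
    assert req.sampling_params == {"max_new_tokens": 1}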
python/sglang/srt/managers/tokenizer_manager.py

@@ -262,6 +262,7 @@ class TokenizerManager:
             ):
                 yield response
         else:
+            assert self.is_generation
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
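For orientation (my gloss, not part of the diff): embedding requests finish at prefill, so only generation models should ever take the cache-prefill wait path, and the new assert makes that invariant fail fast instead of hanging. A toy illustration of the guard, using a stand-in class rather than the real TokenizerManager:

    class Manager:
        def __init__(self, is_generation: bool):
            self.is_generation = is_generation

        def wait_for_cache_prefill(self):
            # Embedding requests must never land here.
            assert self.is_generation
            return "waiting"

    assert Manager(is_generation=True).wait_for_cache_prefill() == "waiting"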
python/sglang/srt/managers/tp_worker.py

@@ -499,6 +499,8 @@ class ModelTpServer:
             req.embedding = embeddings[i]
             if req is not self.current_inflight_req:
                 # Inflight reqs' prefill is not finished
+                # dummy output token for embedding models
+                req.output_ids.append(0)
                 req.check_finished()

             if req.finished():
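Why the dummy token: check_finished() (not shown in this diff) judges completion from the request's output, and an embedding request never decodes real tokens, so appending a single placeholder token lets the normal finish machinery mark it done. A simplified sketch of the idea with stand-in types, not the real Req class:

    class Req:
        def __init__(self):
            self.output_ids = []
            self.embedding = None
            self._finished = False

        def check_finished(self):
            # The real logic also checks EOS/stop and token budgets; for an
            # embedding request, one dummy output token is enough to finish.
            if self.output_ids:
                self._finished = True

        def finished(self):
            return self._finished

    req = Req()
    req.embedding = [0.1, 0.2, 0.3]
    req.output_ids.append(0)  # dummy output token for embedding models
    req.check_finished()
    assert req.finished()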
python/sglang/srt/openai_api/adapter.py

@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)
     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests


-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
-                index=idx,
-                embedding=ret[idx],
-            )
-        )
-    return response
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
+                object="embedding",
+            )
+        )
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )


 async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
     return response
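A quick sketch of what the reworked v1_embedding_response consumes and produces, with a hand-made ret standing in for real engine output (the real list comes back from the tokenizer manager):

    # Hand-made stand-in for engine output; shapes follow the code above.
    ret = [
        {"embedding": [0.12, -0.03, 0.88], "meta_info": {"prompt_tokens": 6}},
        {"embedding": [0.05, 0.41, -0.27], "meta_info": {"prompt_tokens": 7}},
    ]
    response = v1_embedding_response(ret, model_path="intfloat/e5-mistral-7b-instruct")
    assert [obj.index for obj in response.data] == [0, 1]
    assert response.usage.prompt_tokens == 13  # summed over all inputs
    assert response.usage.total_tokens == 13   # embeddings add no completion tokens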
python/sglang/srt/openai_api/protocol.py

@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None


-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
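For reference, how the two new models nest (a sketch; UsageInfo's remaining fields are assumed to carry defaults):

    from sglang.srt.openai_api.protocol import (
        EmbeddingObject,
        EmbeddingResponse,
        UsageInfo,
    )

    resp = EmbeddingResponse(
        data=[EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)],
        model="intfloat/e5-mistral-7b-instruct",
        usage=UsageInfo(prompt_tokens=6, total_tokens=6),
    )
    # Defaults from the hunk above:
    assert resp.object == "list"
    assert resp.data[0].object == "embedding"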
python/sglang/srt/server.py

@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -176,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -412,7 +419,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 0
+    max_new_tokens = 8 if model_info["is_generation"] else 1
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
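A hedged example of exercising the new route with plain requests (the launch command, host, and port are assumptions, not part of this diff; 30000 is sglang's usual default port):

    import requests

    # Assumes a server started with an embedding model, e.g.:
    #   python -m sglang.launch_server --model-path intfloat/e5-mistral-7b-instruct --port 30000
    resp = requests.post(
        "http://127.0.0.1:30000/v1/embeddings",
        json={
            "model": "intfloat/e5-mistral-7b-instruct",
            "input": "The capital of France is",
        },
    )
    body = resp.json()
    print(len(body["data"][0]["embedding"]))  # embedding dimension
    print(body["usage"]["prompt_tokens"])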
test/srt/run_suite.py

@@ -6,6 +6,7 @@ from sglang.test.test_utils import run_unittest_files
 suites = {
     "minimal": [
         "test_eval_accuracy.py",
+        "test_embedding_openai_server.py",
         "test_openai_server.py",
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
test/srt/test_embedding_openai_server.py (new file, 0 → 100644)

import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_prompts = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        assert len(response.data) == num_prompts
        assert isinstance(response.data, list)
        assert response.data[0].embedding
        assert response.data[0].index is not None
        assert response.data[0].object == "embedding"
        assert response.model == self.model
        assert response.object == "list"
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert (
            response.usage.total_tokens == num_prompt_tokens
        ), f"{response.usage.total_tokens} vs {num_prompt_tokens}"

    def run_batch(self):
        # FIXME not implemented
        pass

    def test_embedding(self):
        # TODO the fields of encoding_format, dimensions, user are skipped
        # TODO support use_list_input
        for use_list_input in [False]:
            for token_input in [False, True]:
                self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        self.run_batch()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # t = TestOpenAIServer()
    # t.setUpClass()
    # t.test_embedding()
    # t.tearDownClass()