sglang · commit b16e856f (unverified)
Authored Aug 09, 2024 by Ying Sheng; committed by GitHub on Aug 09, 2024
Parent: 05c50a82

Add openai embedding API (#997)

Showing 8 changed files with 135 additions and 19 deletions (+135 / -19)
python/sglang/srt/managers/io_struct.py           +6   -4
python/sglang/srt/managers/tokenizer_manager.py   +1   -0
python/sglang/srt/managers/tp_worker.py           +2   -0
python/sglang/srt/openai_api/adapter.py           +21  -11
python/sglang/srt/openai_api/protocol.py          +9   -3
python/sglang/srt/server.py                       +8   -1
test/srt/run_suite.py                             +1   -0
test/srt/test_embedding_openai_server.py          +87  -0
python/sglang/srt/managers/io_struct.py

@@ -194,7 +194,8 @@ class EmbeddingReqInput:
         if is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            self.sampling_params = {"max_new_tokens": 0}
+            if self.sampling_params is None:
+                self.sampling_params = {"max_new_tokens": 1}
         else:
             # support select operation
             self.batch_size = (
@@ -205,9 +206,10 @@ class EmbeddingReqInput:
             else:
                 if not isinstance(self.rid, list):
                     raise ValueError("The rid should be a list.")
-            self.sampling_params = [
-                {"max_new_tokens": 0} for _ in range(self.batch_size)
-            ]
+            if self.sampling_params is None:
+                self.sampling_params = [
+                    {"max_new_tokens": 1} for _ in range(self.batch_size)
+                ]


 @dataclass
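Note on the hunks above: the embedding path no longer overwrites a caller-supplied sampling_params and now defaults max_new_tokens to 1 instead of 0. A minimal standalone sketch of the new default-filling behavior (a hypothetical helper written for illustration, not the actual dataclass method):

def fill_embedding_sampling_params(sampling_params, is_single, batch_size):
    # Mirrors the post-init logic after this commit: only fill in a default
    # when the caller did not pass sampling_params explicitly.
    if is_single:
        if sampling_params is None:
            sampling_params = {"max_new_tokens": 1}
    else:
        if sampling_params is None:
            sampling_params = [{"max_new_tokens": 1} for _ in range(batch_size)]
    return sampling_params


assert fill_embedding_sampling_params(None, True, 1) == {"max_new_tokens": 1}
assert fill_embedding_sampling_params({"max_new_tokens": 4}, True, 1) == {"max_new_tokens": 4}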
python/sglang/srt/managers/tokenizer_manager.py

@@ -262,6 +262,7 @@ class TokenizerManager:
             ):
                 yield response
         else:
+            assert self.is_generation
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
python/sglang/srt/managers/tp_worker.py

@@ -499,6 +499,8 @@ class ModelTpServer:
             req.embedding = embeddings[i]
             if req is not self.current_inflight_req:
                 # Inflight reqs' prefill is not finished
+                # dummy output token for embedding models
+                req.output_ids.append(0)
                 req.check_finished()

             if req.finished():
python/sglang/srt/openai_api/adapter.py

@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)
     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests


-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
-                index=idx,
-                embedding=ret[idx],
-                object="embedding",
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
+                index=idx,
             )
         )
-    return response
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )


 async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)

     return response
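To illustrate the new v1_embedding_response signature: it now takes the raw per-prompt results plus the model path, and aggregates prompt token counts into a usage field. A hedged sketch (the values in ret are made up; the dict keys follow the diff above):

from sglang.srt.openai_api.adapter import v1_embedding_response

# Hypothetical per-prompt results in the shape the adapter consumes:
# an embedding vector plus meta_info carrying the prompt token count.
ret = [
    {"embedding": [0.12, -0.03, 0.27], "meta_info": {"prompt_tokens": 6}},
    {"embedding": [0.05, 0.18, -0.09], "meta_info": {"prompt_tokens": 7}},
]

response = v1_embedding_response(ret, model_path="intfloat/e5-mistral-7b-instruct")
assert response.object == "list"
assert [obj.index for obj in response.data] == [0, 1]
assert response.usage.prompt_tokens == 13  # summed across all items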
python/sglang/srt/openai_api/protocol.py

@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None


-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
     object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
     usage: Optional[UsageInfo] = None
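Built from these models, the endpoint's payload follows the OpenAI embeddings response shape. A minimal sketch of constructing one (vector values and token counts are placeholders):

from sglang.srt.openai_api.protocol import (
    EmbeddingObject,
    EmbeddingResponse,
    UsageInfo,
)

resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.1, -0.2, 0.3], index=0)],
    model="intfloat/e5-mistral-7b-instruct",
    usage=UsageInfo(prompt_tokens=6, total_tokens=6),
)
# Roughly serializes to:
# {"data": [{"embedding": [0.1, -0.2, 0.3], "index": 0, "object": "embedding"}],
#  "model": "intfloat/e5-mistral-7b-instruct",
#  "object": "list",
#  "usage": {"prompt_tokens": 6, "total_tokens": 6}}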
python/sglang/srt/server.py

@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -176,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -412,7 +419,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 0
+    max_new_tokens = 8 if model_info["is_generation"] else 1
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
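With the new route registered, any OpenAI-compatible client can call /v1/embeddings. A hedged sketch using the requests library (URL, API key, and model name are placeholders; the Bearer header is only needed if the server was launched with an API key):

import requests

resp = requests.post(
    "http://127.0.0.1:8157/v1/embeddings",
    headers={"Authorization": "Bearer sk-123456"},
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": "The capital of France is",
    },
)
payload = resp.json()
print(payload["object"])                     # "list"
print(len(payload["data"][0]["embedding"]))  # embedding dimension
print(payload["usage"]["prompt_tokens"])     # prompt token count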
test/srt/run_suite.py

@@ -6,6 +6,7 @@ from sglang.test.test_utils import run_unittest_files
 suites = {
     "minimal": [
         "test_eval_accuracy.py",
+        "test_embedding_openai_server.py",
         "test_openai_server.py",
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
test/srt/test_embedding_openai_server.py (new file, mode 0 → 100644)

import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_prompts = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        assert len(response.data) == num_prompts
        assert isinstance(response.data, list)
        assert response.data[0].embedding
        assert response.data[0].index is not None
        assert response.data[0].object == "embedding"
        assert response.model == self.model
        assert response.object == "list"
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert (
            response.usage.total_tokens == num_prompt_tokens
        ), f"{response.usage.total_tokens} vs {num_prompt_tokens}"

    def run_batch(self):
        # FIXME not implemented
        pass

    def test_embedding(self):
        # TODO the fields of encoding_format, dimensions, user are skipped
        # TODO support use_list_input
        for use_list_input in [False]:
            for token_input in [False, True]:
                self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        self.run_batch()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # t = TestOpenAIServer()
    # t.setUpClass()
    # t.test_embedding()
    # t.tearDownClass()