change / sglang · Commit b16e856f (Unverified)
Authored Aug 09, 2024 by Ying Sheng; committed by GitHub on Aug 09, 2024

Add openai embedding API (#997)

Parent: 05c50a82

Showing 8 changed files with 135 additions and 19 deletions (+135 -19)
python/sglang/srt/managers/io_struct.py           +6 -4
python/sglang/srt/managers/tokenizer_manager.py   +1 -0
python/sglang/srt/managers/tp_worker.py           +2 -0
python/sglang/srt/openai_api/adapter.py           +21 -11
python/sglang/srt/openai_api/protocol.py          +9 -3
python/sglang/srt/server.py                       +8 -1
test/srt/run_suite.py                             +1 -0
test/srt/test_embedding_openai_server.py          +87 -0
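Taken together, these changes expose an OpenAI-compatible embeddings endpoint (POST /v1/embeddings) on the sglang SRT server. A minimal client-side sketch with the official openai Python client; the base URL, API key, and model name are assumptions copied from the new test file below, not fixed defaults:

import openai

# Sketch only: the server is assumed to already be running, launched the same
# way as in test/srt/test_embedding_openai_server.py below.
client = openai.Client(
    api_key="sk-123456",                      # assumed key, matches the test
    base_url="http://127.0.0.1:8157/v1",      # assumed host/port, matches the test
)

response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # assumed embedding model, matches the test
    input="The capital of France is",
)

print(len(response.data[0].embedding))        # embedding vector length
print(response.usage.prompt_tokens)           # token usage reported by the server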
python/sglang/srt/managers/io_struct.py
@@ -194,7 +194,8 @@ class EmbeddingReqInput:
         if is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            self.sampling_params = {"max_new_tokens": 0}
+            if self.sampling_params is None:
+                self.sampling_params = {"max_new_tokens": 1}
         else:
             # support select operation
             self.batch_size = (
@@ -205,9 +206,10 @@ class EmbeddingReqInput:
             else:
                 if not isinstance(self.rid, list):
                     raise ValueError("The rid should be a list.")
-            self.sampling_params = [
-                {"max_new_tokens": 0} for _ in range(self.batch_size)
-            ]
+            if self.sampling_params is None:
+                self.sampling_params = [
+                    {"max_new_tokens": 1} for _ in range(self.batch_size)
+                ]


 @dataclass
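Net effect of the two hunks above: EmbeddingReqInput now only fills in sampling_params when the caller left it as None, and the default max_new_tokens moves from 0 to 1 (presumably so the scheduler can emit the dummy output token appended in the tp_worker.py hunk below). A standalone mirror of that default-filling logic, as a sketch rather than the real class method:

# Sketch: mirrors the default-filling behaviour above without importing sglang.
def fill_embedding_defaults(sampling_params, is_single, batch_size=1):
    if is_single:
        if sampling_params is None:
            sampling_params = {"max_new_tokens": 1}
    else:
        if sampling_params is None:
            sampling_params = [{"max_new_tokens": 1} for _ in range(batch_size)]
    return sampling_params

# Defaults are only applied when nothing was provided by the caller.
assert fill_embedding_defaults(None, is_single=True) == {"max_new_tokens": 1}
assert fill_embedding_defaults({"max_new_tokens": 4}, is_single=True) == {"max_new_tokens": 4}
assert fill_embedding_defaults(None, is_single=False, batch_size=2) == [
    {"max_new_tokens": 1},
    {"max_new_tokens": 1},
]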
python/sglang/srt/managers/tokenizer_manager.py
@@ -262,6 +262,7 @@ class TokenizerManager:
         ):
             yield response
         else:
+            assert self.is_generation
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
python/sglang/srt/managers/tp_worker.py
@@ -499,6 +499,8 @@ class ModelTpServer:
                 req.embedding = embeddings[i]
                 if req is not self.current_inflight_req:
                     # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
                     req.check_finished()

                     if req.finished():
python/sglang/srt/openai_api/adapter.py
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)
     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests


-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
                 index=idx,
-                embedding=ret[idx],
-                object="embedding",
             )
         )
-    return response
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )


 async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
     return response
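The reworked v1_embedding_response aggregates the per-request results into a single OpenAI-style response (a data list plus summed usage) instead of returning a bare list of response objects. A plain-dict sketch of that aggregation; the ret items, their token counts, and the model name are invented examples:

# Sketch: the aggregation performed by the new v1_embedding_response, using
# plain dicts instead of the pydantic models; the two ret items are made up.
ret = [
    {"embedding": [0.10, -0.02, 0.33], "meta_info": {"prompt_tokens": 6}},
    {"embedding": [0.07, 0.21, -0.15], "meta_info": {"prompt_tokens": 7}},
]

data = [
    {"object": "embedding", "embedding": item["embedding"], "index": i}
    for i, item in enumerate(ret)
]
prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)

response = {
    "object": "list",
    "data": data,
    "model": "intfloat/e5-mistral-7b-instruct",  # assumed model name
    "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
}
assert response["usage"]["total_tokens"] == 13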
python/sglang/srt/openai_api/protocol.py
@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None


-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
     object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
     usage: Optional[UsageInfo] = None
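A quick sketch of the new response models in use; the values are invented, and UsageInfo's remaining fields are assumed to have defaults (the adapter above constructs it with only prompt_tokens and total_tokens):

# Sketch: build the new models directly; all values are invented.
from sglang.srt.openai_api.protocol import (
    EmbeddingObject,
    EmbeddingResponse,
    UsageInfo,
)

resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)],
    model="intfloat/e5-mistral-7b-instruct",
    usage=UsageInfo(prompt_tokens=6, total_tokens=6),
)
assert resp.object == "list"
assert resp.data[0].object == "embedding"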
python/sglang/srt/server.py
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -176,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -412,7 +419,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 0
+    max_new_tokens = 8 if model_info["is_generation"] else 1
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
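With the route registered, the endpoint can also be exercised with a plain HTTP POST, independent of the openai client. A sketch; the URL, bearer-token auth header, and model name are assumptions matching the values used in the new test file, not documented defaults:

import requests

# Sketch: raw HTTP call to the new route; adjust URL, key, and model for your setup.
resp = requests.post(
    "http://127.0.0.1:8157/v1/embeddings",
    headers={"Authorization": "Bearer sk-123456"},  # auth scheme is an assumption
    json={
        "model": "intfloat/e5-mistral-7b-instruct",
        "input": "The capital of France is",
    },
)
print(resp.json()["usage"])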
test/srt/run_suite.py
@@ -6,6 +6,7 @@ from sglang.test.test_utils import run_unittest_files
 suites = {
     "minimal": [
         "test_eval_accuracy.py",
+        "test_embedding_openai_server.py",
         "test_openai_server.py",
         "test_vision_openai_server.py",
         "test_chunked_prefill.py",
test/srt/test_embedding_openai_server.py
new file (mode 100644)

import json
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server


class TestOpenAIServer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.base_url = "http://127.0.0.1:8157"
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_prompts = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        assert len(response.data) == num_prompts
        assert isinstance(response.data, list)
        assert response.data[0].embedding
        assert response.data[0].index is not None
        assert response.data[0].object == "embedding"
        assert response.model == self.model
        assert response.object == "list"
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert (
            response.usage.total_tokens == num_prompt_tokens
        ), f"{response.usage.total_tokens} vs {num_prompt_tokens}"

    def run_batch(self):
        # FIXME not implemented
        pass

    def test_embedding(self):
        # TODO the fields of encoding_format, dimensions, user are skipped
        # TODO support use_list_input
        for use_list_input in [False]:
            for token_input in [False, True]:
                self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        self.run_batch()


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # t = TestOpenAIServer()
    # t.setUpClass()
    # t.test_embedding()
    # t.tearDownClass()