change / sglang · Commits

Commit 7599bade (unverified)
Authored Aug 10, 2024 by Ying Sheng; committed via GitHub, Aug 10, 2024
Support embedding input as a list (#1014)
Parent: 62757db6

Showing 3 changed files with 65 additions and 50 deletions:
- python/sglang/srt/managers/tokenizer_manager.py (+59, -43)
- python/sglang/test/runners.py (+3, -5)
- test/srt/test_embedding_openai_server.py (+3, -2)
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

             async for response in self._handle_batch_request(obj, request):
```
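The replaced stream check is what actually admits embedding batches: `EmbeddingReqInput` defines no `stream` field, so once the `NotImplementedError` guard is gone, a bare `obj.stream` access would raise `AttributeError`. A minimal sketch of the guard's behavior, using stand-in classes rather than the real definitions (which live in sglang's io_struct module):

```python
# Stand-in request types; only the field relevant to the guard is modeled.
class GenerateReqInput:
    def __init__(self, stream=False):
        self.stream = stream

class EmbeddingReqInput:
    pass  # no `stream` attribute at all

for obj in (GenerateReqInput(stream=True), GenerateReqInput(), EmbeddingReqInput()):
    # hasattr() short-circuits before the attribute access can fail
    print(type(obj).__name__, hasattr(obj, "stream") and obj.stream)
```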
```diff
@@ -283,24 +281,29 @@ class TokenizerManager:
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids

-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size

-        parallel_sample_num = obj.parallel_sample_num
-        if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
-            parallel_sample_num += 1
-            input_id_result = [] if obj.input_ids is None else None
-            for i in range(batch_size):
-                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
-                ):
-                    if input_id_result is not None:
-                        input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
-                obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num
+            if parallel_sample_num != 1:
+                # Send prefill requests to cache the common input
+                parallel_sample_num += 1
+                input_id_result = [] if obj.input_ids is None else None
+                for i in range(batch_size):
+                    async for input_id in self._handle_single_request(
+                        obj, request, index=i, is_cache_for_prefill=True
+                    ):
+                        if input_id_result is not None:
+                            input_id_result.append(input_id)
+                if input_id_result is not None and len(input_id_result) > 1:
+                    obj.input_ids = input_id_result
+                elif input_id_result is not None:
+                    obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1

         # First send out all requests
         for i in range(batch_size):
```
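The effect of the new branch: only generation requests run the parallel-sampling prefill pass that caches a shared prompt prefix; embedding batches always behave as one result per prompt. A condensed view of the control flow (a sketch, not the real implementation):

```python
# Condensed sketch of the updated dispatch; `self` and `obj` stand in for
# the TokenizerManager and the incoming request object.
if self.is_generation:
    parallel_sample_num = obj.parallel_sample_num
    if parallel_sample_num != 1:
        parallel_sample_num += 1  # one extra pass prefills the shared prefix
        ...  # cache-for-prefill loop over the batch (see the diff above)
else:
    parallel_sample_num = 1  # embeddings never sample multiple outputs
```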
```diff
@@ -329,28 +332,38 @@ class TokenizerManager:
                 input_text = None
                 input_ids = obj.input_ids[i]
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[index]
-            )
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid,
-                input_text,
-                input_ids,
-                pixel_values,
-                image_hash,
-                image_size,
-                sampling_params,
-                obj.return_logprob[index],
-                obj.logprob_start_len[index],
-                obj.top_logprobs_num[index],
-                obj.stream,
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )
+                tokenized_obj = TokenizedGenerateReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
+                )
+            else:
+                tokenized_obj = TokenizedEmbeddingReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    sampling_params,
+                )
             self.send_to_router.send_pyobj(tokenized_obj)
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
```
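The embedding payload is a strict subset of the generation payload: no image tensors, no logprob bookkeeping, no stream flag. The field lists below are inferred from the constructor calls in this hunk; the real `TokenizedGenerateReqInput` and `TokenizedEmbeddingReqInput` classes live in sglang's io_struct module and may carry more fields:

```python
# Hypothetical stand-ins mirroring the constructor calls above.
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class EmbeddingPayloadSketch:
    rid: str
    input_text: Optional[str]
    input_ids: List[int]
    sampling_params: Any  # accepted for interface symmetry; nothing is sampled

@dataclass
class GeneratePayloadSketch(EmbeddingPayloadSketch):
    pixel_values: Any = None           # multimodal image features
    image_hash: Optional[int] = None
    image_size: Optional[tuple] = None
    return_logprob: bool = False
    logprob_start_len: int = 0
    top_logprobs_num: int = 0
    stream: bool = False
```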
```diff
@@ -373,14 +386,17 @@ class TokenizerManager:
                     self.abort_request(rid)
                     raise ValueError(f"Abort request {rid}")
                 continue
-            output_list.append(
-                self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob[index],
-                    obj.top_logprobs_num[index],
-                    obj.return_text_in_logprobs,
-                )
-            )
+            if self.is_generation:
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
+                        obj.return_text_in_logprobs,
+                    )
+                )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]
         yield output_list
```
python/sglang/test/runners.py

```diff
@@ -219,11 +219,9 @@ class SRTRunner:
                 output_strs=output_strs,
                 top_input_logprobs=top_input_logprobs,
             )
         else:
-            logits = []
-            for prompt in prompts:
-                response = self.runtime.encode(prompt)
-                response = json.loads(response)
-                logits.append(response["embedding"])
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)

     def __enter__(self):
```
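With list input now supported, the runner makes one `encode` call for the whole batch instead of one per prompt. A hypothetical standalone driver for the same path (the model path and `Runtime` arguments are placeholders, not taken from this commit):

```python
# Placeholder model path; assumes a Runtime launched for an embedding model.
import json
import sglang as sgl

runtime = sgl.Runtime(model_path="intfloat/e5-mistral-7b-instruct")
try:
    prompts = ["The capital of France is Paris.", "SGLang accepts prompt lists."]
    response = json.loads(runtime.encode(prompts))  # single round trip
    embeddings = [item["embedding"] for item in response]
    print(len(embeddings), "embeddings of dimension", len(embeddings[0]))
finally:
    runtime.shutdown()
```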
test/srt/test_embedding_openai_server.py

```diff
@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))

         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1
```
```diff
@@ -70,7 +71,7 @@ class TestOpenAIServer(unittest.TestCase):

     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
```
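Client-side, the newly enabled case is a standard OpenAI-style embeddings call with a list payload. A sketch against a locally launched sglang server (the base URL and model name are placeholders):

```python
# Assumes an sglang OpenAI-compatible server at this placeholder URL.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
resp = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=["Hello world", "Hello world"],  # list input: the case this commit enables
)
assert len(resp.data) == 2  # one embedding per list element
print(len(resp.data[0].embedding))
```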