Commit 7599bade (unverified)

Support embedding input as a list (#1014)

Authored by Ying Sheng on Aug 10, 2024; committed via GitHub on Aug 10, 2024.
Parent: 62757db6

Showing 3 changed files with 65 additions and 50 deletions:

python/sglang/srt/managers/tokenizer_manager.py   +59  -43
python/sglang/test/runners.py                      +3   -5
test/srt/test_embedding_openai_server.py           +3   -2
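This commit lets the embedding path accept a list of inputs in a single request instead of rejecting anything beyond one prompt. A minimal client-side sketch of the new behavior, assuming a locally launched sglang OpenAI-compatible server; the URL, port, and model name are placeholders, not part of this commit:

import openai

# Placeholder server address and embedding model; adjust to your deployment.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Before this change, a list input hit the EmbeddingReqInput branch that raised
# NotImplementedError("Please send only one prompt in each request").
response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # placeholder model name
    input=["Hello world", "The capital of France is Paris."],
)

assert len(response.data) == 2          # one embedding object per list element
print(len(response.data[0].embedding))  # embedding dimension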
python/sglang/srt/managers/tokenizer_manager.py  (view file @ 7599bade)

@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")
             async for response in self._handle_batch_request(obj, request):
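One detail worth noting in the hunk above: the batch-mode stream check now uses hasattr(obj, "stream") rather than reading obj.stream directly, which suggests the embedding request type may not carry a stream field. A small self-contained illustration of why the guard matters, using stand-in dataclasses rather than the real sglang request classes:

from dataclasses import dataclass

@dataclass
class FakeGenerateReqInput:   # stand-in for GenerateReqInput
    stream: bool = False

@dataclass
class FakeEmbeddingReqInput:  # stand-in for EmbeddingReqInput (no stream field)
    pass

def reject_stream_in_batch(obj):
    # Mirrors the new check: only reject when the attribute exists and is truthy.
    if hasattr(obj, "stream") and obj.stream:
        raise ValueError("Do not support stream for batch mode.")

reject_stream_in_batch(FakeEmbeddingReqInput())             # ok: no stream attribute
reject_stream_in_batch(FakeGenerateReqInput(stream=False))  # ok: stream disabled
try:
    reject_stream_in_batch(FakeGenerateReqInput(stream=True))
except ValueError as e:
    print(e)  # Do not support stream for batch mode.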
@@ -283,24 +281,29 @@ class TokenizerManager:
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids

-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.parallel_sample_num

-        if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
-            parallel_sample_num += 1
-            input_id_result = [] if obj.input_ids is None else None
-            for i in range(batch_size):
-                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
-                ):
-                    if input_id_result is not None:
-                        input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
-                obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num
+
+            if parallel_sample_num != 1:
+                # Send prefill requests to cache the common input
+                parallel_sample_num += 1
+                input_id_result = [] if obj.input_ids is None else None
+                for i in range(batch_size):
+                    async for input_id in self._handle_single_request(
+                        obj, request, index=i, is_cache_for_prefill=True
+                    ):
+                        if input_id_result is not None:
+                            input_id_result.append(input_id)
+                if input_id_result is not None and len(input_id_result) > 1:
+                    obj.input_ids = input_id_result
+                elif input_id_result is not None:
+                    obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1

         # First send out all requests
         for i in range(batch_size):

@@ -329,28 +332,38 @@ class TokenizerManager:
                 input_text = None
                 input_ids = obj.input_ids[i]
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[index]
-            )
-            tokenized_obj = TokenizedGenerateReqInput(
-                rid,
-                input_text,
-                input_ids,
-                pixel_values,
-                image_hash,
-                image_size,
-                sampling_params,
-                obj.return_logprob[index],
-                obj.logprob_start_len[index],
-                obj.top_logprobs_num[index],
-                obj.stream,
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )
+                tokenized_obj = TokenizedGenerateReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    pixel_values,
+                    image_hash,
+                    image_size,
+                    sampling_params,
+                    obj.return_logprob[index],
+                    obj.logprob_start_len[index],
+                    obj.top_logprobs_num[index],
+                    obj.stream,
+                )
+            else:
+                tokenized_obj = TokenizedEmbeddingReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    sampling_params,
+                )

             self.send_to_router.send_pyobj(tokenized_obj)
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

         # Then wait for all responses
         output_list = []
         for i in range(batch_size):

@@ -373,14 +386,17 @@ class TokenizerManager:
                     self.abort_request(rid)
                     raise ValueError(f"Abort request {rid}")
                 continue

-            output_list.append(
-                self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob[index],
-                    obj.top_logprobs_num[index],
-                    obj.return_text_in_logprobs,
+            if self.is_generation:
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],
+                        obj.return_logprob[index],
+                        obj.top_logprobs_num[index],
+                        obj.return_text_in_logprobs,
+                    )
                 )
-            )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]

         yield output_list
python/sglang/test/runners.py  (view file @ 7599bade)

@@ -219,11 +219,9 @@ class SRTRunner:
                     output_strs=output_strs,
                     top_input_logprobs=top_input_logprobs,
                 )
             else:
-                logits = []
-                for prompt in prompts:
-                    response = self.runtime.encode(prompt)
-                    response = json.loads(response)
-                    logits.append(response["embedding"])
+                response = self.runtime.encode(prompts)
+                response = json.loads(response)
+                logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)

     def __enter__(self):
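For reference, a minimal sketch of the list-based encode path the updated SRTRunner relies on. The model path and prompts are placeholders; the only parts taken from the diff are that runtime.encode can now be called with the whole list and returns a JSON string with one {"embedding": [...]} entry per prompt:

import json
import sglang as sgl

# Placeholder embedding model path; any embedding-capable model would do here.
runtime = sgl.Runtime(model_path="intfloat/e5-mistral-7b-instruct")
prompts = ["first example text", "second example text"]

response = json.loads(runtime.encode(prompts))   # one call for the whole list
logits = [x["embedding"] for x in response]      # same comprehension as the runner
assert len(logits) == len(prompts)

runtime.shutdown()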
test/srt/test_embedding_openai_server.py  (view file @ 7599bade)

@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))

         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1

@@ -70,7 +71,7 @@ class TestOpenAIServer(unittest.TestCase):
     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
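With use_list_input now looping over both False and True, test_embedding covers four input shapes. A small self-contained sketch of that matrix; the text and token IDs below are placeholders, not values taken from the test:

text = "hello world"
token_ids = [9906, 1917]  # stand-in token IDs for the text

inputs = {
    (False, False): text,            # single text prompt
    (False, True): token_ids,        # single prompt given as token IDs
    (True, False): [text] * 2,       # list of text prompts (newly enabled)
    (True, True): [token_ids] * 2,   # list of token-ID prompts (newly enabled)
}
# Keys are (use_list_input, token_input), matching the nested loops above.
for key, prompt_arg in inputs.items():
    print(key, prompt_arg)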