sglang · Commits
Commit 7599bade (unverified)
Authored Aug 10, 2024 by Ying Sheng; committed by GitHub on Aug 10, 2024

Support embedding input as a list (#1014)

Parent: 62757db6
Showing 3 changed files with 65 additions and 50 deletions:

python/sglang/srt/managers/tokenizer_manager.py   +59 −43
python/sglang/test/runners.py                      +3  −5
test/srt/test_embedding_openai_server.py           +3  −2
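Note (not part of the commit): the change enables list inputs on the OpenAI-compatible /v1/embeddings endpoint. A minimal usage sketch, assuming a locally launched sglang server and a placeholder embedding model name:

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed local server address

payload = {
    "model": "intfloat/e5-mistral-7b-instruct",  # placeholder model name
    # Before this commit the endpoint accepted only a single prompt;
    # now "input" may be a list and one embedding is returned per element.
    "input": ["Hello world", "SGLang now supports batch embedding"],
}

resp = requests.post(f"{BASE_URL}/v1/embeddings", json=payload)
resp.raise_for_status()
data = resp.json()["data"]
print(len(data))                  # 2, one entry per input
print(len(data[0]["embedding"]))  # embedding dimension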
python/sglang/srt/managers/tokenizer_manager.py

@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")

             async for response in self._handle_batch_request(obj, request):

@@ -283,8 +281,11 @@ class TokenizerManager:
         await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
         yield input_ids

-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.parallel_sample_num
-        if parallel_sample_num != 1:
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num
+            if parallel_sample_num != 1:

@@ -301,6 +302,8 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1

         # First send out all requests
         for i in range(batch_size):

@@ -329,6 +332,8 @@ class TokenizerManager:
                 input_text = None
                 input_ids = obj.input_ids[i]
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[index]
-            )
+            if self.is_generation:
+                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                    obj.image_data[index]
+                )

@@ -346,11 +351,19 @@ class TokenizerManager:
                     obj.top_logprobs_num[index],
                     obj.stream,
                 )
+            else:
+                tokenized_obj = TokenizedEmbeddingReqInput(
+                    rid,
+                    input_text,
+                    input_ids,
+                    sampling_params,
+                )
             self.send_to_router.send_pyobj(tokenized_obj)

             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state

         # Then wait for all responses
         output_list = []
         for i in range(batch_size):

@@ -373,6 +386,7 @@ class TokenizerManager:
                 self.abort_request(rid)
                 raise ValueError(f"Abort request {rid}")
                 continue
-            output_list.append(
-                self.convert_logprob_style(
-                    state.out_list[-1],
+            if self.is_generation:
+                output_list.append(
+                    self.convert_logprob_style(
+                        state.out_list[-1],

@@ -381,6 +395,8 @@ class TokenizerManager:
-                    obj.return_text_in_logprobs,
-                )
-            )
+                        obj.return_text_in_logprobs,
+                    )
+                )
+            else:
+                output_list.append(state.out_list[-1])
             assert state.finished
             del self.rid_to_state[rid]
             yield output_list
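Note (illustration, not the commit's code): the reworked _handle_batch_request keeps the existing two-phase pattern — register per-request state and send everything first, then await each response — and now routes embedding requests through it as well as generation requests. A self-contained sketch of that fan-out/fan-in pattern, with hypothetical names mirroring the diff:

import asyncio
from dataclasses import dataclass, field

@dataclass
class ReqState:  # mirrors the shape of the diff's ReqState([], False, event)
    out_list: list = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)

rid_to_state: dict = {}

async def handle_batch(rids):
    # Phase 1: send out all requests before waiting on any of them.
    for rid in rids:
        rid_to_state[rid] = ReqState()
        # self.send_to_router.send_pyobj(tokenized_obj) would happen here.

    # Phase 2: wait for all responses and collect them in request order.
    output_list = []
    for rid in rids:
        state = rid_to_state[rid]
        await state.event.wait()
        assert state.finished
        output_list.append(state.out_list[-1])
        del rid_to_state[rid]
    return output_list

async def fake_worker(rids):
    # Stand-in for the router/detokenizer loop that completes each request.
    for rid in rids:
        await asyncio.sleep(0)
        state = rid_to_state[rid]
        state.out_list.append({"rid": rid, "embedding": [0.0, 0.0]})
        state.finished = True
        state.event.set()

async def main():
    rids = [f"req-{i}" for i in range(3)]
    outputs, _ = await asyncio.gather(handle_batch(rids), fake_worker(rids))
    print(outputs)

asyncio.run(main())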
python/sglang/test/runners.py

@@ -219,11 +219,9 @@ class SRTRunner:
                     output_strs=output_strs, top_input_logprobs=top_input_logprobs
                 )
             else:
-                logits = []
-                for prompt in prompts:
-                    response = self.runtime.encode(prompt)
-                    response = json.loads(response)
-                    logits.append(response["embedding"])
+                response = self.runtime.encode(prompts)
+                response = json.loads(response)
+                logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)

     def __enter__(self):
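Note (illustration only): after this change the test runner calls runtime.encode once with the whole prompt list, and the JSON response is expected to be a list with one object per prompt. A small sketch of that parsing, using a fabricated response in place of the real server output:

import json

# Fabricated response text shaped the way the new list comprehension expects:
# a JSON array containing one {"embedding": [...]} object per prompt.
response_text = json.dumps(
    [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.4, 0.5, 0.6]}]
)

response = json.loads(response_text)
logits = [x["embedding"] for x in response]  # same extraction as the updated runner
assert len(logits) == 2 and len(logits[0]) == 3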
test/srt/test_embedding_openai_server.py

@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))

         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1

@@ -70,7 +71,7 @@ class TestOpenAIServer(unittest.TestCase):
     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
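Note (illustration, not part of the test): with use_list_input enabled, the test exercises the endpoint through the OpenAI client with list input as well as token input. A sketch of those call shapes, assuming a local server, a placeholder model name, and hypothetical token IDs:

import openai

# Assumed local endpoint; the real test derives its base_url from the launched server.
client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
model = "intfloat/e5-mistral-7b-instruct"  # placeholder model name

text = "The capital of France is"
token_ids = [791, 6864, 315, 9822, 374]  # hypothetical token IDs for the prompt

# use_list_input=True, token_input=False: a list of strings in one request.
resp = client.embeddings.create(model=model, input=[text] * 2)
assert len(resp.data) == 2

# use_list_input=False, token_input=True: a single prompt given as token IDs.
resp = client.embeddings.create(model=model, input=token_ids)
assert len(resp.data) == 1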