Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c54269d9
Unverified
Commit
c54269d9
authored
Jun 26, 2024
by
sasha0552
Committed by
GitHub
Jun 26, 2024
Browse files
[Frontend] Add tokenize/detokenize endpoints (#5054)
parent
5bfd1bbc
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
143 additions
and
6 deletions
+143
-6
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+49
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+30
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+21
-0
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+31
-1
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+12
-4
No files found.
tests/entrypoints/test_openai_server.py
View file @
c54269d9
...
...
@@ -9,6 +9,7 @@ import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import
ray
import
requests
import
torch
# downloading lora to test lora requests
from
huggingface_hub
import
snapshot_download
...
...
@@ -1366,5 +1367,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
or
"less_than_equal"
in
exc_info
.
value
.
message
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_tokenize
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
]
tokenizer
=
get_tokenizer
(
tokenizer_name
=
MODEL_NAME
,
tokenizer_mode
=
"fast"
)
for
add_special
in
[
False
,
True
]:
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
"add_special_tokens"
:
add_special
,
"model"
:
model_name
,
"prompt"
:
prompt
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"tokens"
:
tokens
,
"count"
:
len
(
tokens
),
"max_model_len"
:
8192
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_detokenize
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
base_url
=
str
(
client
.
base_url
)[:
-
3
]
tokenizer
=
get_tokenizer
(
tokenizer_name
=
MODEL_NAME
,
tokenizer_mode
=
"fast"
)
prompt
=
"This is a test prompt."
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
response
=
requests
.
post
(
base_url
+
"detokenize"
,
json
=
{
"model"
:
model_name
,
"tokens"
:
tokens
})
response
.
raise_for_status
()
assert
response
.
json
()
==
{
"prompt"
:
prompt
}
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
])
vllm/entrypoints/openai/api_server.py
View file @
c54269d9
...
...
@@ -19,10 +19,17 @@ import vllm.envs as envs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
CompletionRequest
,
EmbeddingRequest
,
ErrorResponse
)
DetokenizeRequest
,
DetokenizeResponse
,
EmbeddingRequest
,
ErrorResponse
,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
...
...
@@ -85,6 +92,28 @@ async def health() -> Response:
return
Response
(
status_code
=
200
)
@
app
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
generator
=
await
openai_serving_completion
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
assert
isinstance
(
generator
,
TokenizeResponse
)
return
JSONResponse
(
content
=
generator
.
model_dump
())
@
app
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
generator
=
await
openai_serving_completion
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
assert
isinstance
(
generator
,
DetokenizeResponse
)
return
JSONResponse
(
content
=
generator
.
model_dump
())
@
app
.
get
(
"/v1/models"
)
async
def
show_available_models
():
models
=
await
openai_serving_chat
.
show_available_models
()
...
...
vllm/entrypoints/openai/protocol.py
View file @
c54269d9
...
...
@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error
:
Optional
[
Any
]
class
TokenizeRequest
(
OpenAIBaseModel
):
model
:
str
prompt
:
str
add_special_tokens
:
bool
=
Field
(
default
=
True
)
class
TokenizeResponse
(
OpenAIBaseModel
):
tokens
:
List
[
int
]
count
:
int
max_model_len
:
int
class
DetokenizeRequest
(
OpenAIBaseModel
):
model
:
str
tokens
:
List
[
int
]
class
DetokenizeResponse
(
OpenAIBaseModel
):
prompt
:
str
vllm/entrypoints/openai/serving_completion.py
View file @
c54269d9
...
...
@@ -16,7 +16,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice
,
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
UsageInfo
)
DetokenizeRequest
,
DetokenizeResponse
,
TokenizeRequest
,
TokenizeResponse
,
UsageInfo
)
# yapf: enable
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
OpenAIServing
)
from
vllm.logger
import
init_logger
...
...
@@ -442,3 +446,29 @@ class OpenAIServingCompletion(OpenAIServing):
tokens
=
out_tokens
,
top_logprobs
=
out_top_logprobs
,
)
async
def
create_tokenize
(
self
,
request
:
TokenizeRequest
)
->
TokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt
=
request
.
prompt
,
add_special_tokens
=
request
.
add_special_tokens
)
return
TokenizeResponse
(
tokens
=
input_ids
,
count
=
len
(
input_ids
),
max_model_len
=
self
.
max_model_len
)
async
def
create_detokenize
(
self
,
request
:
DetokenizeRequest
)
->
DetokenizeResponse
:
error_check_ret
=
await
self
.
_check_model
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
(
input_ids
,
input_text
)
=
self
.
_validate_prompt_and_tokenize
(
request
,
prompt_ids
=
request
.
tokens
)
return
DetokenizeResponse
(
prompt
=
input_text
)
vllm/entrypoints/openai/serving_engine.py
View file @
c54269d9
...
...
@@ -10,9 +10,10 @@ from vllm.config import ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
ErrorResponse
,
ModelCard
,
ModelList
,
ModelPermission
)
ModelPermission
,
TokenizeRequest
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
...
...
@@ -99,8 +100,9 @@ class OpenAIServing:
return
json_str
async
def
_check_model
(
self
,
request
:
Union
[
CompletionRequest
,
ChatCompletionRequest
,
EmbeddingRequest
]
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
]
)
->
Optional
[
ErrorResponse
]:
if
request
.
model
in
self
.
served_model_names
:
return
None
...
...
@@ -126,7 +128,8 @@ class OpenAIServing:
def
_validate_prompt_and_tokenize
(
self
,
request
:
Union
[
ChatCompletionRequest
,
CompletionRequest
,
EmbeddingRequest
],
DetokenizeRequest
,
EmbeddingRequest
,
TokenizeRequest
],
prompt
:
Optional
[
str
]
=
None
,
prompt_ids
:
Optional
[
List
[
int
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
...
...
@@ -174,6 +177,11 @@ class OpenAIServing:
f
"generation. Please reduce the length of the input."
,
)
return
input_ids
,
input_text
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if
isinstance
(
request
,
(
TokenizeRequest
,
DetokenizeRequest
)):
return
input_ids
,
input_text
if
request
.
max_tokens
is
None
:
if
token_num
>=
self
.
max_model_len
:
raise
ValueError
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment