Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67bdf8e5
Unverified
Commit
67bdf8e5
authored
Oct 29, 2024
by
Joe Runde
Committed by
GitHub
Oct 29, 2024
Browse files
[Bugfix][Frontend] Guard against bad token ids (#9634)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
0ad216f5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
89 additions
and
17 deletions
+89
-17
tests/entrypoints/llm/test_prompt_validation.py
tests/entrypoints/llm/test_prompt_validation.py
+7
-1
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+9
-9
tests/entrypoints/openai/test_prompt_validation.py
tests/entrypoints/openai/test_prompt_validation.py
+15
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+12
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+36
-4
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+5
-0
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+5
-0
No files found.
tests/entrypoints/llm/test_prompt_validation.py
View file @
67bdf8e5
...
@@ -4,6 +4,12 @@ from vllm import LLM
...
@@ -4,6 +4,12 @@ from vllm import LLM
def
test_empty_prompt
():
def
test_empty_prompt
():
llm
=
LLM
(
model
=
"gpt2"
)
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
'Prompt cannot be empty'
):
with
pytest
.
raises
(
ValueError
,
match
=
'Prompt cannot be empty'
):
llm
.
generate
([
""
])
llm
.
generate
([
""
])
def
test_out_of_vocab_token
():
llm
=
LLM
(
model
=
"gpt2"
,
enforce_eager
=
True
)
with
pytest
.
raises
(
ValueError
,
match
=
'out of vocabulary'
):
llm
.
generate
({
"prompt_token_ids"
:
[
999999
]})
tests/entrypoints/openai/test_completion.py
View file @
67bdf8e5
...
@@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
...
@@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_added_lora_tokens_base_model
(
client
:
openai
.
AsyncOpenAI
):
async
def
test_added_lora_tokens_base_model
(
client
:
openai
.
AsyncOpenAI
):
# test using token IDs
# test using token IDs
completion
=
await
client
.
completions
.
create
(
with
pytest
.
raises
(
openai
.
BadRequestError
,
match
=
"out of vocabulary"
):
model
=
MODEL_NAME
,
# Added tokens should be rejected by the base model
prompt
=
[
0
,
0
,
32000
,
32001
,
32002
],
await
client
.
completions
.
create
(
echo
=
True
,
model
=
MODEL_NAME
,
max_tokens
=
5
,
prompt
=
[
0
,
0
,
32000
,
32001
,
32002
]
,
temperature
=
0.0
,
echo
=
True
,
)
max_tokens
=
5
,
# Added tokens should not appear in tokenized prompt
temperature
=
0.0
,
assert
"vllm"
not
in
completion
.
choices
[
0
].
text
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/openai/test_prompt_validation.py
View file @
67bdf8e5
...
@@ -20,3 +20,18 @@ async def test_empty_prompt():
...
@@ -20,3 +20,18 @@ async def test_empty_prompt():
prompt
=
""
,
prompt
=
""
,
max_tokens
=
5
,
max_tokens
=
5
,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
asyncio
async
def
test_out_of_vocab_token_ids
():
model_name
=
"gpt2"
server_args
=
[
"--enforce-eager"
]
with
RemoteOpenAIServer
(
model_name
,
server_args
)
as
remote_server
:
client
=
remote_server
.
get_async_client
()
with
pytest
.
raises
(
openai
.
BadRequestError
,
match
=
re
.
compile
(
'.*out of vocabulary.*'
)):
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
[
999999
],
max_tokens
=
5
,
temperature
=
0.0
)
vllm/engine/async_llm_engine.py
View file @
67bdf8e5
...
@@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -412,6 +412,12 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
async
def
get_tokenizer_async
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
AnyTokenizer
:
return
await
(
self
.
get_tokenizer_group
().
get_lora_tokenizer_async
(
lora_request
))
@
overload
# DEPRECATED
@
overload
# DEPRECATED
async
def
add_request_async
(
async
def
add_request_async
(
self
,
self
,
...
@@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -472,6 +478,10 @@ class _AsyncLLMEngine(LLMEngine):
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
if
self
.
tokenizer
is
not
None
:
tokenizer
=
await
self
.
get_tokenizer_async
(
lora_request
)
self
.
_validate_token_prompt
(
prompt
,
tokenizer
=
tokenizer
)
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
...
@@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -488,7 +498,7 @@ class _AsyncLLMEngine(LLMEngine):
# implementation in the LLMEngine
# implementation in the LLMEngine
params
=
await
build_guided_decoding_logits_processor_async
(
params
=
await
build_guided_decoding_logits_processor_async
(
sampling_params
=
params
,
sampling_params
=
params
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
),
tokenizer
=
await
self
.
get_tokenizer
_async
(
lora_request
),
default_guided_backend
=
self
.
decoding_config
.
default_guided_backend
=
self
.
decoding_config
.
guided_decoding_backend
)
guided_decoding_backend
)
...
@@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -715,8 +725,7 @@ class AsyncLLMEngine(EngineClient):
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
)
->
AnyTokenizer
:
return
await
(
self
.
engine
.
get_tokenizer_group
().
return
await
self
.
engine
.
get_tokenizer_async
(
lora_request
)
get_lora_tokenizer_async
(
lora_request
))
def
start_background_loop
(
self
)
->
None
:
def
start_background_loop
(
self
)
->
None
:
"""Start the background loop."""
"""Start the background loop."""
...
...
vllm/engine/llm_engine.py
View file @
67bdf8e5
...
@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
...
@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
import
torch
import
torch
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeIs
,
TypeVar
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
...
@@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
...
@@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
EncoderDecoderInputs
,
InputRegistry
,
PromptType
)
EncoderDecoderInputs
,
InputRegistry
,
PromptType
,
TokensPrompt
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logits_process
import
get_bad_words_logits_processors
from
vllm.logits_process
import
get_bad_words_logits_processors
...
@@ -667,7 +668,7 @@ class LLMEngine:
...
@@ -667,7 +668,7 @@ class LLMEngine:
)
)
return
None
return
None
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
,
lora_request
)
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
seq_id
=
next
(
self
.
seq_counter
)
...
@@ -829,6 +830,11 @@ class LLMEngine:
...
@@ -829,6 +830,11 @@ class LLMEngine:
if
arrival_time
is
None
:
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
arrival_time
=
time
.
time
()
if
self
.
tokenizer
is
not
None
:
self
.
_validate_token_prompt
(
prompt
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
))
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
...
@@ -855,6 +861,31 @@ class LLMEngine:
...
@@ -855,6 +861,31 @@ class LLMEngine:
priority
=
priority
,
priority
=
priority
,
)
)
def
_validate_token_prompt
(
self
,
prompt
:
PromptType
,
tokenizer
:
AnyTokenizer
):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if
self
.
_is_token_prompt
(
prompt
):
prompt_ids
=
prompt
[
"prompt_token_ids"
]
if
len
(
prompt_ids
)
==
0
:
# Empty prompt check is handled later
return
max_input_id
=
max
(
prompt_ids
)
if
max_input_id
>
tokenizer
.
max_token_id
:
raise
ValueError
(
"Token id {} is out of vocabulary"
.
format
(
max_input_id
))
@
staticmethod
def
_is_token_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
TokensPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"prompt_token_ids"
in
prompt
def
_create_sequence_group_with_sampling
(
def
_create_sequence_group_with_sampling
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
...
@@ -1942,7 +1973,8 @@ class LLMEngine:
...
@@ -1942,7 +1973,8 @@ class LLMEngine:
return
self
.
input_preprocessor
.
is_encoder_decoder_model
()
return
self
.
input_preprocessor
.
is_encoder_decoder_model
()
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnlyInputs
,
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]):
EncoderDecoderInputs
],
lora_request
:
Optional
[
LoRARequest
]):
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
# For encoder-decoder multimodal models, the max_prompt_len
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
# restricts the decoder prompt length
...
...
vllm/transformers_utils/tokenizer.py
View file @
67bdf8e5
...
@@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
...
@@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer
.
all_special_tokens_extended
)
tokenizer
.
all_special_tokens_extended
)
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
tokenizer_len
=
len
(
tokenizer
)
tokenizer_len
=
len
(
tokenizer
)
max_token_id
=
max
(
tokenizer
.
get_vocab
().
values
())
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
class
CachedTokenizer
(
tokenizer
.
__class__
):
# type: ignore
...
@@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
...
@@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def
all_special_tokens_extended
(
self
):
def
all_special_tokens_extended
(
self
):
return
tokenizer_all_special_tokens_extended
return
tokenizer_all_special_tokens_extended
@
property
def
max_token_id
(
self
):
return
max_token_id
def
__len__
(
self
):
def
__len__
(
self
):
return
tokenizer_len
return
tokenizer_len
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
67bdf8e5
...
@@ -85,6 +85,7 @@ class MistralTokenizer:
...
@@ -85,6 +85,7 @@ class MistralTokenizer:
raise
TypeError
(
f
"Unsupported tokenizer:
{
type
(
tokenizer_
)
}
"
)
raise
TypeError
(
f
"Unsupported tokenizer:
{
type
(
tokenizer_
)
}
"
)
self
.
tokenizer
=
tokenizer_
self
.
tokenizer
=
tokenizer_
self
.
_max_token_id
=
max
(
self
.
_vocab
.
values
())
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
def
from_pretrained
(
cls
,
...
@@ -158,6 +159,10 @@ class MistralTokenizer:
...
@@ -158,6 +159,10 @@ class MistralTokenizer:
def
vocab_size
(
self
)
->
int
:
def
vocab_size
(
self
)
->
int
:
return
len
(
self
.
_vocab
)
return
len
(
self
.
_vocab
)
@
property
def
max_token_id
(
self
)
->
int
:
return
self
.
_max_token_id
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
self
.
vocab_size
return
self
.
vocab_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment