Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
887d7af8
Unverified
Commit
887d7af8
authored
May 04, 2025
by
Cyrus Leung
Committed by
GitHub
May 04, 2025
Browse files
[Core] Gate `prompt_embeds` behind a feature flag (#17607)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
a9284245
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
84 additions
and
4 deletions
+84
-4
tests/engine/test_options.py
tests/engine/test_options.py
+60
-0
tests/models/language/generation/test_common.py
tests/models/language/generation/test_common.py
+6
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+3
-0
vllm/config.py
vllm/config.py
+4
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-0
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+4
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-1
No files found.
tests/engine/test_
skip_tokenizer_init
.py
→
tests/engine/test_
options
.py
View file @
887d7af8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
nullcontext
import
pytest
import
pytest
...
@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
...
@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
llm
=
LLM
(
llm
=
LLM
(
model
=
model
,
model
=
model
,
skip_tokenizer_init
=
True
,
skip_tokenizer_init
=
True
,
enforce_eager
=
True
,
)
)
sampling_params
=
SamplingParams
(
prompt_logprobs
=
True
,
detokenize
=
True
)
sampling_params
=
SamplingParams
(
prompt_logprobs
=
True
,
detokenize
=
True
)
...
@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
...
@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
assert
len
(
completions
)
>
0
assert
len
(
completions
)
>
0
assert
completions
[
0
].
text
==
""
assert
completions
[
0
].
text
==
""
assert
completions
[
0
].
token_ids
assert
completions
[
0
].
token_ids
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"distilbert/distilgpt2"
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
def
test_enable_prompt_embeds
(
hf_runner
,
model
:
str
,
enable_prompt_embeds
:
bool
):
prompt
=
"abc"
with
hf_runner
(
model
)
as
hf_model
:
token_ids
=
hf_model
.
tokenizer
(
prompt
,
return_tensors
=
"pt"
).
input_ids
token_ids
=
token_ids
.
to
(
hf_model
.
model
.
device
)
embed_layer
=
hf_model
.
model
.
get_input_embeddings
()
prompt_embeds
=
embed_layer
(
token_ids
).
squeeze
(
0
)
ctx
=
(
nullcontext
()
if
enable_prompt_embeds
else
pytest
.
raises
(
ValueError
,
match
=
"set `--enable-prompt-embeds`"
))
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm
=
LLM
(
model
=
model
,
enable_prompt_embeds
=
enable_prompt_embeds
,
enforce_eager
=
True
,
)
with
ctx
:
llm
.
generate
({
"prompt_embeds"
:
prompt_embeds
})
tests/models/language/generation/test_common.py
View file @
887d7af8
...
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
...
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators
# in parts of the operators
pytest
.
skip
(
f
"Skipping '
{
model
}
' model test with AITER kernel."
)
pytest
.
skip
(
f
"Skipping '
{
model
}
' model test with AITER kernel."
)
use_prompt_embeds
=
os
.
getenv
(
"VLLM_USE_V1"
)
==
"0"
with
hf_runner
(
model
)
as
hf_model
:
with
hf_runner
(
model
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
prompt_embeds
:
Optional
[
list
[
torch
.
Tensor
]]
=
[]
if
os
.
getenv
(
prompt_embeds
:
Optional
[
list
[
torch
.
Tensor
]]
=
([]
if
use_prompt_embeds
"VLLM_USE_V1"
)
==
"0"
else
None
else
None
)
prompt_token_ids
=
[]
prompt_token_ids
=
[]
for
prompt
in
example_prompts
:
for
prompt
in
example_prompts
:
token_ids
=
hf_model
.
tokenizer
(
prompt
,
token_ids
=
hf_model
.
tokenizer
(
prompt
,
...
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
...
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
trust_remote_code
=
model_info
.
trust_remote_code
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
enable_prompt_embeds
=
use_prompt_embeds
,
)
as
vllm_model
:
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
...
...
tests/worker/test_model_runner.py
View file @
887d7af8
...
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
...
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
enable_prompt_embeds
=
True
,
)
)
seq_lens
:
list
[
int
]
=
[]
seq_lens
:
list
[
int
]
=
[]
...
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
...
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enable_chunked_prefill
=
False
,
enable_prompt_embeds
=
True
,
)
)
context_lens
:
list
[
int
]
=
[]
context_lens
:
list
[
int
]
=
[]
...
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
...
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
max_num_batched_tokens
=
100000
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
True
,
enable_chunked_prefill
=
True
,
enable_prompt_embeds
=
True
,
)
)
# Add prefill requests.
# Add prefill requests.
...
...
vllm/config.py
View file @
887d7af8
...
@@ -321,6 +321,10 @@ class ModelConfig:
...
@@ -321,6 +321,10 @@ class ModelConfig:
"""Skip initialization of tokenizer and detokenizer. Expects valid
"""Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated
`prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids."""
output will contain token ids."""
enable_prompt_embeds
:
bool
=
False
"""If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required
for graph compilation."""
served_model_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
served_model_name
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
None
"""The model name(s) used in the API. If multiple names are provided, the
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
server will respond to any of the provided names. The model name in the
...
...
vllm/engine/arg_utils.py
View file @
887d7af8
...
@@ -234,6 +234,7 @@ class EngineArgs:
...
@@ -234,6 +234,7 @@ class EngineArgs:
hf_config_path
:
Optional
[
str
]
=
ModelConfig
.
hf_config_path
hf_config_path
:
Optional
[
str
]
=
ModelConfig
.
hf_config_path
task
:
TaskOption
=
ModelConfig
.
task
task
:
TaskOption
=
ModelConfig
.
task
skip_tokenizer_init
:
bool
=
ModelConfig
.
skip_tokenizer_init
skip_tokenizer_init
:
bool
=
ModelConfig
.
skip_tokenizer_init
enable_prompt_embeds
:
bool
=
ModelConfig
.
enable_prompt_embeds
tokenizer_mode
:
TokenizerMode
=
ModelConfig
.
tokenizer_mode
tokenizer_mode
:
TokenizerMode
=
ModelConfig
.
tokenizer_mode
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
...
@@ -445,6 +446,8 @@ class EngineArgs:
...
@@ -445,6 +446,8 @@ class EngineArgs:
**
model_kwargs
[
"disable_cascade_attn"
])
**
model_kwargs
[
"disable_cascade_attn"
])
model_group
.
add_argument
(
"--skip-tokenizer-init"
,
model_group
.
add_argument
(
"--skip-tokenizer-init"
,
**
model_kwargs
[
"skip_tokenizer_init"
])
**
model_kwargs
[
"skip_tokenizer_init"
])
model_group
.
add_argument
(
"--enable-prompt-embeds"
,
**
model_kwargs
[
"enable_prompt_embeds"
])
model_group
.
add_argument
(
"--served-model-name"
,
model_group
.
add_argument
(
"--served-model-name"
,
**
model_kwargs
[
"served_model_name"
])
**
model_kwargs
[
"served_model_name"
])
# This one is a special case because it is the
# This one is a special case because it is the
...
@@ -874,6 +877,7 @@ class EngineArgs:
...
@@ -874,6 +877,7 @@ class EngineArgs:
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_sliding_window
=
self
.
disable_sliding_window
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
disable_cascade_attn
=
self
.
disable_cascade_attn
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
skip_tokenizer_init
=
self
.
skip_tokenizer_init
,
enable_prompt_embeds
=
self
.
enable_prompt_embeds
,
served_model_name
=
self
.
served_model_name
,
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
...
...
vllm/inputs/preprocess.py
View file @
887d7af8
...
@@ -303,8 +303,11 @@ class InputPreprocessor:
...
@@ -303,8 +303,11 @@ class InputPreprocessor:
self
,
self
,
parsed_content
:
EmbedsPrompt
,
parsed_content
:
EmbedsPrompt
,
)
->
EmbedsInputs
:
)
->
EmbedsInputs
:
if
not
self
.
model_config
.
enable_prompt_embeds
:
raise
ValueError
(
"You must set `--enable-prompt-embeds` to input "
"`prompt_embeds`."
)
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"prompt_embeds is only available in V0."
)
raise
ValueError
(
"
`
prompt_embeds
`
is only available in V0."
)
prompt_embeds
=
parsed_content
[
"prompt_embeds"
]
prompt_embeds
=
parsed_content
[
"prompt_embeds"
]
...
...
vllm/worker/model_runner.py
View file @
887d7af8
...
@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# product.
# product.
cudagraph_capture_sizes
=
self
.
vllm_config
.
compilation_config
\
cudagraph_capture_sizes
=
self
.
vllm_config
.
compilation_config
\
.
cudagraph_capture_sizes
.
cudagraph_capture_sizes
cudagraph_inputs_embeds
=
(
True
,
False
)
cudagraph_inputs_embeds
=
((
True
,
False
)
if
self
.
model_config
.
enable_prompt_embeds
else
(
False
,
))
compilation_cases
=
itertools
.
product
(
compilation_cases
=
itertools
.
product
(
cudagraph_capture_sizes
,
cudagraph_capture_sizes
,
cudagraph_inputs_embeds
,
cudagraph_inputs_embeds
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment