Unverified Commit 887d7af8 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Gate `prompt_embeds` behind a feature flag (#17607)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a9284245
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext
import pytest import pytest
...@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str): ...@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
llm = LLM( llm = LLM(
model=model, model=model,
skip_tokenizer_init=True, skip_tokenizer_init=True,
enforce_eager=True,
) )
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
...@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str): ...@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
assert len(completions) > 0 assert len(completions) > 0
assert completions[0].text == "" assert completions[0].text == ""
assert completions[0].token_ids assert completions[0].token_ids
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_enable_prompt_embeds(hf_runner, model: str,
enable_prompt_embeds: bool):
prompt = "abc"
with hf_runner(model) as hf_model:
token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
token_ids = token_ids.to(hf_model.model.device)
embed_layer = hf_model.model.get_input_embeddings()
prompt_embeds = embed_layer(token_ids).squeeze(0)
ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
ValueError, match="set `--enable-prompt-embeds`"))
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(
model=model,
enable_prompt_embeds=enable_prompt_embeds,
enforce_eager=True,
)
with ctx:
llm.generate({"prompt_embeds": prompt_embeds})
...@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ...@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators # in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.") pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
with hf_runner(model) as hf_model: with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv( prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
"VLLM_USE_V1") == "0" else None else None)
prompt_token_ids = [] prompt_token_ids = []
for prompt in example_prompts: for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt, token_ids = hf_model.tokenizer(prompt,
...@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ...@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
tokenizer_mode=model_info.tokenizer_mode, tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
max_num_seqs=2, max_num_seqs=2,
enable_prompt_embeds=use_prompt_embeds,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
......
...@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): ...@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens=100000, max_num_batched_tokens=100000,
max_num_seqs=100000, max_num_seqs=100000,
enable_chunked_prefill=False, enable_chunked_prefill=False,
enable_prompt_embeds=True,
) )
seq_lens: list[int] = [] seq_lens: list[int] = []
...@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): ...@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens=100000, max_num_batched_tokens=100000,
max_num_seqs=100000, max_num_seqs=100000,
enable_chunked_prefill=False, enable_chunked_prefill=False,
enable_prompt_embeds=True,
) )
context_lens: list[int] = [] context_lens: list[int] = []
...@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, ...@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
max_num_batched_tokens=100000, max_num_batched_tokens=100000,
max_num_seqs=100000, max_num_seqs=100000,
enable_chunked_prefill=True, enable_chunked_prefill=True,
enable_prompt_embeds=True,
) )
# Add prefill requests. # Add prefill requests.
......
...@@ -321,6 +321,10 @@ class ModelConfig: ...@@ -321,6 +321,10 @@ class ModelConfig:
"""Skip initialization of tokenizer and detokenizer. Expects valid """Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated `prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids.""" output will contain token ids."""
enable_prompt_embeds: bool = False
"""If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required
for graph compilation."""
served_model_name: Optional[Union[str, list[str]]] = None served_model_name: Optional[Union[str, list[str]]] = None
"""The model name(s) used in the API. If multiple names are provided, the """The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the server will respond to any of the provided names. The model name in the
......
...@@ -234,6 +234,7 @@ class EngineArgs: ...@@ -234,6 +234,7 @@ class EngineArgs:
hf_config_path: Optional[str] = ModelConfig.hf_config_path hf_config_path: Optional[str] = ModelConfig.hf_config_path
task: TaskOption = ModelConfig.task task: TaskOption = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
...@@ -445,6 +446,8 @@ class EngineArgs: ...@@ -445,6 +446,8 @@ class EngineArgs:
**model_kwargs["disable_cascade_attn"]) **model_kwargs["disable_cascade_attn"])
model_group.add_argument("--skip-tokenizer-init", model_group.add_argument("--skip-tokenizer-init",
**model_kwargs["skip_tokenizer_init"]) **model_kwargs["skip_tokenizer_init"])
model_group.add_argument("--enable-prompt-embeds",
**model_kwargs["enable_prompt_embeds"])
model_group.add_argument("--served-model-name", model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"]) **model_kwargs["served_model_name"])
# This one is a special case because it is the # This one is a special case because it is the
...@@ -874,6 +877,7 @@ class EngineArgs: ...@@ -874,6 +877,7 @@ class EngineArgs:
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn, disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
......
...@@ -303,8 +303,11 @@ class InputPreprocessor: ...@@ -303,8 +303,11 @@ class InputPreprocessor:
self, self,
parsed_content: EmbedsPrompt, parsed_content: EmbedsPrompt,
) -> EmbedsInputs: ) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError("You must set `--enable-prompt-embeds` to input "
"`prompt_embeds`.")
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.") raise ValueError("`prompt_embeds` is only available in V0.")
prompt_embeds = parsed_content["prompt_embeds"] prompt_embeds = parsed_content["prompt_embeds"]
......
...@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# product. # product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\ cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes .cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False) cudagraph_inputs_embeds = ((
True, False) if self.model_config.enable_prompt_embeds else
(False, ))
compilation_cases = itertools.product( compilation_cases = itertools.product(
cudagraph_capture_sizes, cudagraph_capture_sizes,
cudagraph_inputs_embeds, cudagraph_inputs_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment