Unverified Commit 887d7af8 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Gate `prompt_embeds` behind a feature flag (#17607)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent a9284245
# SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext
import pytest
......@@ -14,6 +15,7 @@ def test_skip_tokenizer_initialization(model: str):
llm = LLM(
model=model,
skip_tokenizer_init=True,
enforce_eager=True,
)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
......@@ -27,3 +29,32 @@ def test_skip_tokenizer_initialization(model: str):
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_enable_prompt_embeds(hf_runner, model: str,
enable_prompt_embeds: bool):
prompt = "abc"
with hf_runner(model) as hf_model:
token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
token_ids = token_ids.to(hf_model.model.device)
embed_layer = hf_model.model.get_input_embeddings()
prompt_embeds = embed_layer(token_ids).squeeze(0)
ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
ValueError, match="set `--enable-prompt-embeds`"))
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(
model=model,
enable_prompt_embeds=enable_prompt_embeds,
enforce_eager=True,
)
with ctx:
llm.generate({"prompt_embeds": prompt_embeds})
......@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
"VLLM_USE_V1") == "0" else None
prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
else None)
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
......@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
max_num_seqs=2,
enable_prompt_embeds=use_prompt_embeds,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
......
......@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enable_prompt_embeds=True,
)
seq_lens: list[int] = []
......@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enable_prompt_embeds=True,
)
context_lens: list[int] = []
......@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=True,
enable_prompt_embeds=True,
)
# Add prefill requests.
......
......@@ -321,6 +321,10 @@ class ModelConfig:
"""Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids."""
enable_prompt_embeds: bool = False
"""If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required
for graph compilation."""
served_model_name: Optional[Union[str, list[str]]] = None
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
......
......@@ -234,6 +234,7 @@ class EngineArgs:
hf_config_path: Optional[str] = ModelConfig.hf_config_path
task: TaskOption = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
......@@ -445,6 +446,8 @@ class EngineArgs:
**model_kwargs["disable_cascade_attn"])
model_group.add_argument("--skip-tokenizer-init",
**model_kwargs["skip_tokenizer_init"])
model_group.add_argument("--enable-prompt-embeds",
**model_kwargs["enable_prompt_embeds"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
# This one is a special case because it is the
......@@ -874,6 +877,7 @@ class EngineArgs:
disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
......
......@@ -303,8 +303,11 @@ class InputPreprocessor:
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError("You must set `--enable-prompt-embeds` to input "
"`prompt_embeds`.")
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
raise ValueError("`prompt_embeds` is only available in V0.")
prompt_embeds = parsed_content["prompt_embeds"]
......
......@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False)
cudagraph_inputs_embeds = ((
True, False) if self.model_config.enable_prompt_embeds else
(False, ))
compilation_cases = itertools.product(
cudagraph_capture_sizes,
cudagraph_inputs_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment