Unverified commit 9a4600e4, authored by Andrew Sansom, committed by GitHub
Browse files

[CORE] Prompt Embeddings Support for v1 Engine (#24278)


Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 9fac6aa3
...@@ -76,11 +76,6 @@ def test_models( ...@@ -76,11 +76,6 @@ def test_models(
model_executor: str, model_executor: str,
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
if async_scheduling: if async_scheduling:
pytest.skip("async_scheduling only supported in v1.") pytest.skip("async_scheduling only supported in v1.")
...@@ -164,11 +159,6 @@ def test_models_distributed( ...@@ -164,11 +159,6 @@ def test_models_distributed(
extra_env: dict[str, str], extra_env: dict[str, str],
enable_prompt_embeds: bool, enable_prompt_embeds: bool,
) -> None: ) -> None:
if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")
if test_suite != TARGET_TEST_SUITE: if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}") pytest.skip(f"Skip test for {test_suite}")
......
...@@ -36,7 +36,6 @@ def default_server_args() -> list[str]: ...@@ -36,7 +36,6 @@ def default_server_args() -> list[str]:
"--enforce-eager", "--enforce-eager",
# Prompt Embeds server args # Prompt Embeds server args
"--enable-prompt-embeds", "--enable-prompt-embeds",
"--no-enable-chunked-prefill",
] ]
......
...@@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ...@@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators # in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.") pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
# Note: can be removed when
# https://github.com/vllm-project/vllm/pull/24278 finished
if current_platform.is_cpu() and use_prompt_embeds:
pytest.skip("Skipping use_prompt_embeds=True with "
"V1-only CPU backend.")
with hf_runner(model) as hf_model: with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
......
...@@ -1513,12 +1513,6 @@ class EngineArgs: ...@@ -1513,12 +1513,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No text embedding inputs so far.
if self.enable_prompt_embeds:
_raise_or_fallback(feature_name="--enable-prompt-embeds",
recommend_to_remove=False)
return False
# No Mamba or Encoder-Decoder so far. # No Mamba or Encoder-Decoder so far.
if not model_config.is_v1_compatible: if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures, _raise_or_fallback(feature_name=model_config.architectures,
...@@ -1651,6 +1645,13 @@ class EngineArgs: ...@@ -1651,6 +1645,13 @@ class EngineArgs:
"models in V0 and has been disabled.") "models in V0 and has been disabled.")
self.enable_prefix_caching = False self.enable_prefix_caching = False
if self.enable_prompt_embeds:
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V0. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False
# Set max_num_seqs to 256 for VLLM_V0. # Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None: if self.max_num_seqs is None:
self.max_num_seqs = 256 self.max_num_seqs = 256
...@@ -1664,6 +1665,17 @@ class EngineArgs: ...@@ -1664,6 +1665,17 @@ class EngineArgs:
# For pooling tasks the default is False # For pooling tasks the default is False
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
# TODO: When prefix caching supports prompt embeds inputs, this
# check can be removed.
if (self.enable_prompt_embeds
and self.enable_prefix_caching is not False):
logger.warning(
"--enable-prompt-embeds and --enable-prefix-caching "
"are not supported together in V1. Prefix caching has "
"been disabled.")
self.enable_prefix_caching = False
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = True self.enable_prefix_caching = True
else: else:
......
...@@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create # https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None model: Optional[str] = None
prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
best_of: Optional[int] = None best_of: Optional[int] = None
echo: Optional[bool] = False echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
...@@ -1009,6 +1008,7 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1009,6 +1008,7 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-sampling-params] # --8<-- [end:completion-sampling-params]
# --8<-- [start:completion-extra-params] # --8<-- [start:completion-extra-params]
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=True, default=True,
description=( description=(
......
...@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None: ...@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
pid = os.getpid() pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid) _add_prefix(sys.stderr, process_name, pid)
def length_from_prompt_token_ids_or_embeds(
    prompt_token_ids: Optional[list[int]],
    prompt_embeds: Optional[torch.Tensor],
) -> int:
    """Calculate the request length (in number of tokens) given either
    prompt_token_ids or prompt_embeds.

    Args:
        prompt_token_ids: Token ids of the prompt, or None when the request
            carries no token ids.
        prompt_embeds: Prompt embeddings, or None when the request carries
            no embeddings. Its length is obtained with ``len()``.

    Returns:
        The length of whichever input was provided. When both are provided
        they must agree, and that common length is returned.

    Raises:
        ValueError: If both inputs are None, or if both are provided but
            their lengths differ.
    """
    prompt_token_len = (None if prompt_token_ids is None else
                        len(prompt_token_ids))
    prompt_embeds_len = (None
                         if prompt_embeds is None else len(prompt_embeds))

    # No token ids: the embeddings are the only possible source of a length.
    if prompt_token_len is None:
        if prompt_embeds_len is None:
            raise ValueError(
                "Neither prompt_token_ids nor prompt_embeds were defined.")
        return prompt_embeds_len

    # Token ids are present; if embeddings are present too, the two inputs
    # must describe the same number of positions.
    if (prompt_embeds_len is not None
            and prompt_embeds_len != prompt_token_len):
        raise ValueError(
            "Prompt token ids and prompt embeds had different lengths"
            f" prompt_token_ids={prompt_token_len}"
            f" prompt_embeds={prompt_embeds_len}")
    return prompt_token_len
...@@ -11,6 +11,7 @@ from vllm._bc_linter import bc_linter_include ...@@ -11,6 +11,7 @@ from vllm._bc_linter import bc_linter_include
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata) KVConnectorMetadata)
...@@ -26,13 +27,14 @@ if TYPE_CHECKING: ...@@ -26,13 +27,14 @@ if TYPE_CHECKING:
class NewRequestData: class NewRequestData:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: Optional[list[int]]
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...] block_ids: tuple[list[int], ...]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
prompt_embeds: Optional[torch.Tensor] = None
@classmethod @classmethod
def from_request( def from_request(
...@@ -49,9 +51,12 @@ class NewRequestData: ...@@ -49,9 +51,12 @@ class NewRequestData:
block_ids=block_ids, block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens, num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request, lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
) )
def __repr__(self): def __repr__(self) -> str:
prompt_embeds_shape = (self.prompt_embeds.shape
if self.prompt_embeds else None)
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids}," f"prompt_token_ids={self.prompt_token_ids},"
...@@ -59,19 +64,26 @@ class NewRequestData: ...@@ -59,19 +64,26 @@ class NewRequestData:
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")") ")")
# Version of __repr__ with the prompt data obfuscated # Version of __repr__ with the prompt data obfuscated
def anon_repr(self): def anon_repr(self) -> str:
prompt_token_ids_len = len(
self.prompt_token_ids
) if self.prompt_token_ids is not None else None
prompt_embeds_shape = (self.prompt_embeds.shape
if self.prompt_embeds else None)
return (f"NewRequestData(" return (f"NewRequestData("
f"req_id={self.req_id}," f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)}," f"prompt_token_ids_len={prompt_token_ids_len},"
f"mm_features={self.mm_features}," f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params}," f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids}," f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens}," f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}" f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")") ")")
......
...@@ -47,7 +47,7 @@ class EngineCoreRequest( ...@@ -47,7 +47,7 @@ class EngineCoreRequest(
gc=False): # type: ignore[call-arg] gc=False): # type: ignore[call-arg]
request_id: str request_id: str
prompt_token_ids: list[int] prompt_token_ids: Optional[list[int]]
mm_features: Optional[list[MultiModalFeatureSpec]] mm_features: Optional[list[MultiModalFeatureSpec]]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
...@@ -56,6 +56,7 @@ class EngineCoreRequest( ...@@ -56,6 +56,7 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
cache_salt: Optional[str] cache_salt: Optional[str]
data_parallel_rank: Optional[int] data_parallel_rank: Optional[int]
prompt_embeds: Optional[torch.Tensor] = None
# Index of the client, used to ensure outputs are sent back to the same # Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end. # client for this request when scaling out the front-end.
......
...@@ -13,6 +13,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker ...@@ -13,6 +13,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import ( from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -179,11 +180,12 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): ...@@ -179,11 +180,12 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
self.tokenizer: Tokenizer = tokenizer._tokenizer self.tokenizer: Tokenizer = tokenizer._tokenizer
# Find a safe place to start. # Find a safe place to start.
prompt_suffix = request.prompt_token_ids prompt_token_ids = request.prompt_token_ids or []
prompt_suffix = prompt_token_ids
prompt_len = len(prompt_suffix) prompt_len = len(prompt_suffix)
if prompt_len > 4: if prompt_len > 4:
for i in range(4, min(prompt_len + 1, 24)): for i in range(4, min(prompt_len + 1, 24)):
suffix = request.prompt_token_ids[-i:] suffix = prompt_token_ids[-i:]
if '�' not in self.tokenizer.decode(suffix): if '�' not in self.tokenizer.decode(suffix):
prompt_suffix = suffix prompt_suffix = suffix
break break
...@@ -260,16 +262,25 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): ...@@ -260,16 +262,25 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
params = request.sampling_params params = request.sampling_params
assert params is not None assert params is not None
self.prompt_len = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
# Metadata for incremental detokenization. # Metadata for incremental detokenization.
self.tokens, self.prefix_offset, self.read_offset = ( if request.prompt_token_ids is not None:
convert_prompt_ids_to_tokens( self.tokens, self.prefix_offset, self.read_offset = (
tokenizer=tokenizer, convert_prompt_ids_to_tokens(
prompt_ids=request.prompt_token_ids, tokenizer=tokenizer,
skip_special_tokens=params.skip_special_tokens, prompt_ids=request.prompt_token_ids,
)) skip_special_tokens=params.skip_special_tokens,
))
else:
# Prompt embedding requests cannot be detokenized, in general.
self.tokens = [""] * self.prompt_len
self.prefix_offset = 0
self.read_offest = 0
self.token_ids.extend(request.prompt_token_ids) self.token_ids.extend(request.prompt_token_ids
self.prompt_len = len(request.prompt_token_ids) or [0] * self.prompt_len)
self.skip_special_tokens = params.skip_special_tokens self.skip_special_tokens = params.skip_special_tokens
self.spaces_between_special_tokens = ( self.spaces_between_special_tokens = (
......
...@@ -14,6 +14,7 @@ from vllm.sampling_params import RequestOutputKind ...@@ -14,6 +14,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer, from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
extract_trace_context) extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.logprobs import LogprobsProcessor
...@@ -86,7 +87,8 @@ class RequestState: ...@@ -86,7 +87,8 @@ class RequestState:
lora_name: Optional[str], lora_name: Optional[str],
output_kind: RequestOutputKind, output_kind: RequestOutputKind,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: list[int], prompt_token_ids: Optional[list[int]],
prompt_embeds: Optional[torch.Tensor],
logprobs_processor: Optional[LogprobsProcessor], logprobs_processor: Optional[LogprobsProcessor],
detokenizer: Optional[IncrementalDetokenizer], detokenizer: Optional[IncrementalDetokenizer],
max_tokens_param: Optional[int], max_tokens_param: Optional[int],
...@@ -104,7 +106,9 @@ class RequestState: ...@@ -104,7 +106,9 @@ class RequestState:
self.output_kind = output_kind self.output_kind = output_kind
self.prompt = prompt self.prompt = prompt
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids) self.prompt_embeds = prompt_embeds
self.prompt_len = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
self.logprobs_processor = logprobs_processor self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer self.detokenizer = detokenizer
self.max_tokens_param = max_tokens_param self.max_tokens_param = max_tokens_param
...@@ -165,6 +169,7 @@ class RequestState: ...@@ -165,6 +169,7 @@ class RequestState:
output_kind=output_kind, output_kind=output_kind,
prompt=prompt, prompt=prompt,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
logprobs_processor=logprobs_processor, logprobs_processor=logprobs_processor,
detokenizer=detokenizer, detokenizer=detokenizer,
max_tokens_param=max_tokens_param, max_tokens_param=max_tokens_param,
...@@ -223,6 +228,8 @@ class RequestState: ...@@ -223,6 +228,8 @@ class RequestState:
first_output = outputs[0] first_output = outputs[0]
if isinstance(first_output, PoolingOutput): if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1 assert len(outputs) == 1
# Prompt embeddings are currently not supported by pooling requests.
assert self.prompt_token_ids is not None
return PoolingRequestOutput( return PoolingRequestOutput(
request_id=request_id, request_id=request_id,
outputs=first_output, outputs=first_output,
...@@ -236,10 +243,15 @@ class RequestState: ...@@ -236,10 +243,15 @@ class RequestState:
else: else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs prompt_logprobs = self.logprobs_processor.prompt_logprobs
# If prompt embeds were used, put placeholder prompt token ids
prompt_token_ids = self.prompt_token_ids
if prompt_token_ids is None and self.prompt_embeds is not None:
prompt_token_ids = [0] * len(self.prompt_embeds)
return RequestOutput( return RequestOutput(
request_id=request_id, request_id=request_id,
prompt=self.prompt, prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=prompt_logprobs, prompt_logprobs=prompt_logprobs,
outputs=cast(list[CompletionOutput], outputs), outputs=cast(list[CompletionOutput], outputs),
finished=finished, finished=finished,
...@@ -469,6 +481,8 @@ class OutputProcessor: ...@@ -469,6 +481,8 @@ class OutputProcessor:
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers) trace_context = extract_trace_context(engine_core_output.trace_headers)
prompt_length = length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds)
with (self.tracer.start_as_current_span( with (self.tracer.start_as_current_span(
"llm_request", "llm_request",
kind=SpanKind.SERVER, kind=SpanKind.SERVER,
...@@ -488,7 +502,7 @@ class OutputProcessor: ...@@ -488,7 +502,7 @@ class OutputProcessor:
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
queued_time) queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(req_state.prompt_token_ids)) prompt_length)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens) metrics.num_generation_tokens)
span.set_attribute( span.set_attribute(
...@@ -544,7 +558,8 @@ class OutputProcessor: ...@@ -544,7 +558,8 @@ class OutputProcessor:
assert req_state.stats is not None assert req_state.stats is not None
iteration_stats.update_from_finished_request( iteration_stats.update_from_finished_request(
finish_reason=finish_reason, finish_reason=finish_reason,
num_prompt_tokens=len(req_state.prompt_token_ids), num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
req_state.prompt_token_ids, req_state.prompt_embeds),
max_tokens_param=req_state.max_tokens_param, max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats) req_stats=req_state.stats)
self.lora_states.finish_request(req_state) self.lora_states.finish_request(req_state)
......
...@@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions ...@@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar) validate_guidance_grammar)
...@@ -390,6 +391,16 @@ class Processor: ...@@ -390,6 +391,16 @@ class Processor:
self._validate_model_inputs(processed_inputs) self._validate_model_inputs(processed_inputs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
# Mypy does not always properly infer the types of some elements of
# discriminated unions of TypedDicts, because of how it handles
# inheritance of TypedDict. If we explicitly extract the items we want
# we can avoid type errors from using `dict.get` later in the method.
prompt_str: Optional[str] = None if decoder_inputs[
"type"] == "embeds" else decoder_inputs.get("prompt")
prompt_token_ids = decoder_inputs[
"prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
"type"] == "embeds" else None
sampling_params = None sampling_params = None
pooling_params = None pooling_params = None
...@@ -398,9 +409,10 @@ class Processor: ...@@ -398,9 +409,10 @@ class Processor:
sampling_params = params.clone() sampling_params = params.clone()
# If unset max tokens, then generate up to the max_model_len. # If unset max tokens, then generate up to the max_model_len.
if sampling_params.max_tokens is None: if sampling_params.max_tokens is None:
sampling_params.max_tokens = ( seq_len = length_from_prompt_token_ids_or_embeds(
self.model_config.max_model_len - prompt_token_ids, prompt_embeds)
len(decoder_inputs["prompt_token_ids"])) sampling_params.max_tokens = \
self.model_config.max_model_len - seq_len
sampling_params.update_from_generation_config( sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id) self.generation_config_fields, eos_token_id)
if self.tokenizer is not None: if self.tokenizer is not None:
...@@ -430,9 +442,10 @@ class Processor: ...@@ -430,9 +442,10 @@ class Processor:
identifier=decoder_mm_hashes[modality][idx], identifier=decoder_mm_hashes[modality][idx],
mm_position=decoder_mm_positions[modality][idx])) mm_position=decoder_mm_positions[modality][idx]))
return decoder_inputs.get("prompt"), EngineCoreRequest( return prompt_str, EngineCoreRequest(
request_id=request_id, request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"], prompt_token_ids=prompt_token_ids,
prompt_embeds=prompt_embeds,
mm_features=mm_features, mm_features=mm_features,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
...@@ -461,10 +474,17 @@ class Processor: ...@@ -461,10 +474,17 @@ class Processor:
): ):
model_config = self.model_config model_config = self.model_config
prompt_ids = prompt_inputs["prompt_token_ids"] prompt_ids = None if prompt_inputs[
"type"] == "embeds" else prompt_inputs["prompt_token_ids"]
prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[
"type"] == "embeds" else None
prompt_len = length_from_prompt_token_ids_or_embeds(
prompt_ids, prompt_embeds)
if not prompt_ids: if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data pass # Mllama may have empty encoder inputs for text-only data
elif prompt_inputs["type"] == "embeds":
pass # Prompt embeds should not have prompt_ids.
else: else:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
...@@ -472,7 +492,7 @@ class Processor: ...@@ -472,7 +492,7 @@ class Processor:
tokenizer = None tokenizer = None
else: else:
tokenizer = self.tokenizer tokenizer = self.tokenizer
max_input_id = max(prompt_ids, default=0) max_input_id = max(prompt_ids or [], default=0)
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
# self.model_config.get_vocab_size() is the model’s vocab size. # self.model_config.get_vocab_size() is the model’s vocab size.
...@@ -490,7 +510,7 @@ class Processor: ...@@ -490,7 +510,7 @@ class Processor:
f"Token id {max_input_id} is out of vocabulary") f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len: if prompt_len > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( mm_processor = mm_registry.create_processor(
...@@ -514,7 +534,7 @@ class Processor: ...@@ -514,7 +534,7 @@ class Processor:
"number of text tokens.") "number of text tokens.")
raise ValueError( raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is " f"The {prompt_type} prompt (length {prompt_len}) is "
f"longer than the maximum model length of {max_prompt_len}. " f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}") f"{suggestion}")
......
...@@ -7,9 +7,12 @@ from collections.abc import Mapping ...@@ -7,9 +7,12 @@ from collections.abc import Mapping
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason) EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.structured_output.request import StructuredOutputRequest
...@@ -25,12 +28,13 @@ class Request: ...@@ -25,12 +28,13 @@ class Request:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: Optional[list[int]],
sampling_params: Optional[SamplingParams], sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams], pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int], eos_token_id: Optional[int],
client_index: int = 0, client_index: int = 0,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
prompt_embeds: Optional[torch.Tensor] = None,
mm_features: Optional[list[MultiModalFeatureSpec]] = None, mm_features: Optional[list[MultiModalFeatureSpec]] = None,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
...@@ -79,9 +83,13 @@ class Request: ...@@ -79,9 +83,13 @@ class Request:
"sampling_params and pooling_params can't both be unset") "sampling_params and pooling_params can't both be unset")
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
self.num_prompt_tokens = len(self.prompt_token_ids) self.prompt_embeds = prompt_embeds
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
prompt_token_ids, prompt_embeds)
self._output_token_ids: list[int] = [] self._output_token_ids: list[int] = []
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy(
) if self.prompt_token_ids is not None else [0
] * self.num_prompt_tokens
self.num_output_placeholders = 0 # Used in async scheduling. self.num_output_placeholders = 0 # Used in async scheduling.
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
...@@ -123,6 +131,7 @@ class Request: ...@@ -123,6 +131,7 @@ class Request:
request_id=request.request_id, request_id=request.request_id,
client_index=request.client_index, client_index=request.client_index,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
prompt_embeds=request.prompt_embeds,
mm_features=request.mm_features, mm_features=request.mm_features,
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
pooling_params=request.pooling_params, pooling_params=request.pooling_params,
......
...@@ -243,7 +243,7 @@ class AdapterLogitsProcessor(LogitsProcessor): ...@@ -243,7 +243,7 @@ class AdapterLogitsProcessor(LogitsProcessor):
def _new_state( def _new_state(
self, self,
params: SamplingParams, params: SamplingParams,
prompt_ids: list[int], prompt_ids: Optional[list[int]],
output_ids: list[int], output_ids: list[int],
) -> Optional[partial[torch.Tensor]]: ) -> Optional[partial[torch.Tensor]]:
"""Return state representation for new request """Return state representation for new request
......
...@@ -187,7 +187,8 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -187,7 +187,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
@staticmethod @staticmethod
def add_request( def add_request(
params: SamplingParams, _: list[int], output_tok_ids: list[int] params: SamplingParams, _: Optional[list[int]],
output_tok_ids: list[int]
) -> Optional[tuple[int, Sequence[int], set[int]]]: ) -> Optional[tuple[int, Sequence[int], set[int]]]:
min_tokens = params.min_tokens min_tokens = params.min_tokens
if not min_tokens or len(output_tok_ids) >= min_tokens: if not min_tokens or len(output_tok_ids) >= min_tokens:
...@@ -234,7 +235,8 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -234,7 +235,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
def process_dict_updates( def process_dict_updates(
req_entries: dict[int, T], batch_update: Optional[BatchUpdate], req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]] new_state: Callable[[SamplingParams, Optional[list[int]], list[int]],
Optional[T]]
) -> bool: ) -> bool:
"""Utility function to update dict state for sparse LogitsProcessors.""" """Utility function to update dict state for sparse LogitsProcessors."""
......
...@@ -26,7 +26,7 @@ RemovedRequest = int ...@@ -26,7 +26,7 @@ RemovedRequest = int
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new # (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch. # requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]] AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]]
# (index 1, index 2, directionality) tuples representing # (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch # one-way moves or two-way swaps of requests in batch
......
...@@ -174,7 +174,7 @@ class MsgpackEncoder: ...@@ -174,7 +174,7 @@ class MsgpackEncoder:
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None assert self.aux_buffers is not None
# view the tensor as a contiguous 1D array of bytes # view the tensor as a contiguous 1D array of bytes
arr = obj.flatten().contiguous().view(torch.uint8).numpy() arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold: if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays. # Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
......
...@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest ...@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
...@@ -29,7 +29,7 @@ from vllm.v1.worker.block_table import MultiGroupBlockTable ...@@ -29,7 +29,7 @@ from vllm.v1.worker.block_table import MultiGroupBlockTable
class CachedRequestState: class CachedRequestState:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: Optional[list[int]]
mm_features: list[MultiModalFeatureSpec] mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams] pooling_params: Optional[PoolingParams]
...@@ -43,9 +43,11 @@ class CachedRequestState: ...@@ -43,9 +43,11 @@ class CachedRequestState:
mrope_position_delta: Optional[int] = None mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids) self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@property @property
def num_tokens(self) -> int: def num_tokens(self) -> int:
...@@ -63,6 +65,10 @@ class CachedRequestState: ...@@ -63,6 +65,10 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int: def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens: if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx] return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids): elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens] return self.output_token_ids[idx - self.num_prompt_tokens]
...@@ -109,6 +115,14 @@ class InputBatch: ...@@ -109,6 +115,14 @@ class InputBatch:
pin_memory=False, pin_memory=False,
) )
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
...@@ -310,15 +324,23 @@ class InputBatch: ...@@ -310,15 +324,23 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids. # Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids) num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids) end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index, self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu. self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of tokens in the request (ids in token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens. # NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens. # Number of tokens without spec decode tokens.
...@@ -503,6 +525,20 @@ class InputBatch: ...@@ -503,6 +525,20 @@ class InputBatch:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp self.token_ids_cpu[i2, ...] = tmp
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)
self.block_table.swap_row(i1, i2) self.block_table.swap_row(i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
...@@ -592,6 +628,11 @@ class InputBatch: ...@@ -592,6 +628,11 @@ class InputBatch:
num_tokens = self.num_tokens[last_req_index] num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens] last_req_index, :num_tokens]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[
empty_index] = self.req_prompt_embeds.pop(last_req_index)
self.num_tokens[empty_index] = num_tokens self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index] last_req_index]
......
...@@ -56,7 +56,9 @@ from vllm.sequence import IntermediateTensors, PoolerOutput ...@@ -56,7 +56,9 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, check_use_alibi, get_dtype_size, GiB_bytes, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up, supports_dynamo) is_pin_memory_available,
length_from_prompt_token_ids_or_embeds, round_up,
supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
...@@ -197,6 +199,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -197,6 +199,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cache_config.cache_dtype] cache_config.cache_dtype]
self.is_pooling_model = (model_config.runner_type == 'pooling') self.is_pooling_model = (model_config.runner_type == 'pooling')
self.enable_prompt_embeds = model_config.enable_prompt_embeds
self.is_multimodal_raw_input_only_model = ( self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model) model_config.is_multimodal_raw_input_only_model)
...@@ -342,6 +345,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -342,6 +345,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size, self.hidden_size,
dtype=self.dtype, dtype=self.dtype,
numpy=False) numpy=False)
self.is_token_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
self.discard_request_indices = self._make_buffer(self.max_num_reqs, self.discard_request_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int64) dtype=torch.int64)
self.num_discarded_requests = 0 self.num_discarded_requests = 0
...@@ -574,6 +579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -574,6 +579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = CachedRequestState( req_state = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_features=new_req_data.mm_features, mm_features=new_req_data.mm_features,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=pooling_params, pooling_params=pooling_params,
...@@ -819,6 +825,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -819,6 +825,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.input_batch.prev_sampled_token_ids is None: if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case # Normal scheduling case
self.input_ids.copy_to_gpu(total_num_scheduled_tokens) self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return return
# Async scheduling case, where some decode requests from the previous # Async scheduling case, where some decode requests from the previous
...@@ -844,6 +852,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -844,6 +852,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# If not all requests are decodes from the last iteration, # If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first. # We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens) self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0: if num_commmon_tokens == 0:
# No requests in common with the previous iteration # No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids. # So input_ids_cpu will have all the input ids.
...@@ -857,6 +867,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -857,6 +867,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0], 0],
non_blocking=True) non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
return return
# Upload the index tensors asynchronously # Upload the index tensors asynchronously
# so the scatter can be non-blocking. # so the scatter can be non-blocking.
...@@ -947,14 +958,60 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -947,14 +958,60 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# where M is the max_model_len. # where M is the max_model_len.
token_indices = (positions_np + token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1]) req_indices * self.input_batch.token_ids_cpu.shape[1])
token_indices_tensor = torch.from_numpy(token_indices)
# NOTE(woosuk): We use torch.index_select instead of np.take here # NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large # because torch.index_select is much faster than np.take for large
# tensors. # tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0, 0,
torch.from_numpy(token_indices), token_indices_tensor,
out=self.input_ids.cpu[:total_num_scheduled_tokens]) out=self.input_ids.cpu[:total_num_scheduled_tokens])
is_token_ids = self.input_batch.is_token_ids.flatten()
torch.index_select(
is_token_ids,
0,
token_indices_tensor,
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
# the InputBatch, we need to fill in the prompt embeds into the expected
# spots in the GPUModelRunner's pre-allocated inputs_embeds tensor.
if self.input_batch.req_prompt_embeds:
output_idx = 0
for req_idx in range(num_reqs):
num_sched = num_scheduled_tokens[req_idx]
# Skip if this request doesn't have embeddings
if req_idx not in self.input_batch.req_prompt_embeds:
output_idx += num_sched
continue
# Skip if no tokens scheduled
if num_sched <= 0:
output_idx += num_sched
continue
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
# Skip if trying to read beyond available embeddings
if start_pos >= req_embeds.shape[0]:
output_idx += num_sched
continue
# Copy available embeddings
end_pos = start_pos + num_sched
actual_end = min(end_pos, req_embeds.shape[0])
actual_num_sched = actual_end - start_pos
if actual_num_sched > 0:
self.inputs_embeds.cpu[output_idx:output_idx +
actual_num_sched].copy_(
req_embeds[start_pos:actual_end]
)
output_idx += num_sched
self.input_batch.block_table.compute_slot_mapping( self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np) req_indices, positions_np)
...@@ -1279,7 +1336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1279,7 +1336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[index] self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \ num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id] scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids) num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0, prompt_part_len = max(0,
...@@ -1845,6 +1903,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1845,6 +1903,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**self._init_model_kwargs(num_scheduled_tokens), **self._init_model_kwargs(num_scheduled_tokens),
**self._extract_mm_kwargs(scheduler_output), **self._extract_mm_kwargs(scheduler_output),
} }
elif (self.enable_prompt_embeds and get_pp_group().is_first_rank):
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
# TODO(qthequartermasterman): Since even when prompt embeds are
# enabled, (a) not all requests will use prompt embeds, and (b)
# after the initial prompt is processed, the rest of the generated
# tokens will be token ids, it is not desirable to have the
# embedding layer outside of the CUDA graph all the time. The v0
# engine avoids this by "double compiling" the CUDA graph, once
# with input_ids and again with inputs_embeds, for all num_tokens.
# If a batch only has token ids, then including the embedding layer
# in the CUDA graph will be more performant (like in the else case
# below).
token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \
.nonzero(as_tuple=False) \
.squeeze(1)
# Some token ids may need to become embeds
if token_ids_idx.numel() > 0:
token_ids = self.input_ids.gpu[token_ids_idx]
tokens_to_embeds = self.model.get_input_embeddings(
input_ids=token_ids)
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
model_kwargs = self._init_model_kwargs(num_input_tokens)
input_ids = None
else: else:
# For text-only models, we use token ids as input. # For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the # While it is possible to use embeddings as input just like the
...@@ -2023,6 +2107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2023,6 +2107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.token_ids_cpu[req_idx, self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx
...@@ -2570,6 +2655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2570,6 +2655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Get metadata for this request. # Get metadata for this request.
request = self.requests[req_id] request = self.requests[req_id]
if request.prompt_token_ids is None:
# Prompt logprobs is incompatible with prompt embeddings
continue
num_prompt_tokens = len(request.prompt_token_ids) num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to( prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True) self.device, non_blocking=True)
...@@ -2922,6 +3011,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -2922,6 +3011,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**model_kwargs, **model_kwargs,
**self._dummy_mm_kwargs(num_reqs), **self._dummy_mm_kwargs(num_reqs),
} }
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
model_kwargs = self._init_model_kwargs(num_tokens)
else: else:
input_ids = self.input_ids.gpu[:num_tokens] input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None inputs_embeds = None
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.gpu_input_batch import CachedRequestState
...@@ -213,7 +213,9 @@ class InputBatch: ...@@ -213,7 +213,9 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids. # Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids) num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
# TODO: copy prompt_embeds
self.num_prompt_tokens[req_index] = num_prompt_tokens self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[ self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids req_index, :num_prompt_tokens] = request.prompt_token_ids
......
...@@ -387,6 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -387,6 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState( self.requests[req_id] = CachedRequestState(
req_id=req_id, req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids, prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_features=new_req_data.mm_features, mm_features=new_req_data.mm_features,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None, pooling_params=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment