Unverified Commit cee711fd authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Rename input data types (#8688)

parent 1de76a0e
...@@ -15,8 +15,8 @@ from typing_extensions import NotRequired ...@@ -15,8 +15,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
from vllm.logger import init_logger token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -37,8 +37,6 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, ...@@ -37,8 +37,6 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__)
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -252,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, ...@@ -252,10 +250,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
def input_processor_when_multimodal_input_image(ctx: InputContext, def input_processor_when_multimodal_input_image(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig) hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
...@@ -290,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -290,7 +288,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return input_processor_for_clip( return input_processor_for_clip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -298,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -298,7 +296,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
return input_processor_for_siglip( return input_processor_for_siglip(
model_config, model_config,
vision_config, vision_config,
llm_inputs, inputs,
image_token_id=hf_config.image_token_index, image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size, image_feature_size_override=image_feature_size,
) )
...@@ -308,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext, ...@@ -308,10 +306,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
def input_processor_when_multimodal_input_video(ctx: InputContext, def input_processor_when_multimodal_input_video(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data: if multi_modal_data is None or "video" not in multi_modal_data:
return llm_inputs return inputs
video_data = multi_modal_data["video"] video_data = multi_modal_data["video"]
model_config = ctx.model_config model_config = ctx.model_config
...@@ -326,13 +324,13 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -326,13 +324,13 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index, placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size, repeat_count=video_feature_size,
) )
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
...@@ -345,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext, ...@@ -345,15 +343,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
def input_processor_for_llava_onevision(ctx: InputContext, def input_processor_for_llava_onevision(ctx: InputContext,
llm_inputs: LLMInputs): inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data): and "image" not in multi_modal_data):
return llm_inputs return inputs
if "image" in multi_modal_data: if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, llm_inputs) return input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data: if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, llm_inputs) return input_processor_when_multimodal_input_video(ctx, inputs)
msg = "Unsupported multi data type" msg = "Unsupported multi data type"
raise NotImplementedError(msg) raise NotImplementedError(msg)
......
...@@ -36,7 +36,8 @@ from typing_extensions import NotRequired ...@@ -36,7 +36,8 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
...@@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext): ...@@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
return SequenceData.from_token_counts((0, seq_len)) return SequenceData.from_prompt_token_counts((0, seq_len))
def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
...@@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, ...@@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
return seq_data, mm_data return seq_data, mm_data
def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
version = get_version_by_config(model_config.hf_config) version = get_version_by_config(model_config.hf_config)
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
...@@ -297,8 +298,8 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -297,8 +298,8 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
return image_processor. \ return image_processor. \
get_slice_image_placeholder(image_size, num_image) get_slice_image_placeholder(image_size, num_image)
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
token_ids = llm_inputs.get("prompt_token_ids") token_ids = inputs.get("prompt_token_ids")
if prompt is None: if prompt is None:
prompt = tokenizer.decode(token_ids) prompt = tokenizer.decode(token_ids)
...@@ -332,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -332,12 +333,11 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
_build_image_input(ctx, image) for image in images _build_image_input(ctx, image) for image in images
] ]
llm_inputs = LLMInputs( return token_inputs(
prompt_token_ids=new_token_ids, prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
return llm_inputs
def input_mapper_for_minicpmv(ctx: InputContext, data: object): def input_mapper_for_minicpmv(ctx: InputContext, data: object):
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch Mllama model.""" """PyTorch Mllama model."""
import math import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -37,7 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType ...@@ -37,7 +36,8 @@ from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
EncoderDecoderInputs, InputContext)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import SequenceData
from .clip import CLIPMLP from .clip import CLIPMLP
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
...@@ -86,24 +86,24 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int: ...@@ -86,24 +86,24 @@ def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
return num_images return num_images
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_mllama(ctx: InputContext,
inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
# move encoder_prompt to prompt # move encoder_prompt to prompt
if llm_inputs.get("prompt") is None: if inputs.get("prompt") is None:
llm_inputs["prompt"] = llm_inputs["encoder_prompt"] inputs["prompt"] = inputs["encoder_prompt"]
llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
# process multi-modal data # process multi-modal data
assert "decoder_multi_modal_data" not in llm_inputs, \ multi_modal_data = inputs.get("encoder_multi_modal_data")
"multi-modal data should be put in encoder message of mllama"
multi_modal_data = llm_inputs.get("encoder_multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data \ if multi_modal_data is None or "image" not in multi_modal_data \
or multi_modal_data["image"] is None: or multi_modal_data["image"] is None:
# text-only # text-only
llm_inputs["encoder_prompt"] = "" inputs["encoder_prompt"] = ""
llm_inputs["encoder_prompt_token_ids"] = [] inputs["encoder_prompt_token_ids"] = []
llm_inputs["encoder_multi_modal_data"] = {} inputs["encoder_multi_modal_data"] = {}
return llm_inputs return inputs
if isinstance(multi_modal_data['image'], Image.Image): if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']] multi_modal_data['image'] = [multi_modal_data['image']]
...@@ -111,7 +111,7 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -111,7 +111,7 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
# are attended by the decoded tokens, we only need to # are attended by the decoded tokens, we only need to
# get the number of tiles for those images. # get the number of tiles for those images.
num_decode_images = _get_num_image_in_last_group( num_decode_images = _get_num_image_in_last_group(
llm_inputs["prompt_token_ids"]) inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config hf_config = ctx.model_config.hf_config
num_tiles = 0 num_tiles = 0
for image in multi_modal_data["image"][::-1]: for image in multi_modal_data["image"][::-1]:
...@@ -137,11 +137,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -137,11 +137,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
"chunk size should be multiple of 14" "chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
num_tokens = num_tiles * token_per_chunk num_tokens = num_tiles * token_per_chunk
llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
] * num_tokens
return llm_inputs return inputs
def get_max_mllama_image_tokens(ctx: InputContext) -> int: def get_max_mllama_image_tokens(ctx: InputContext) -> int:
...@@ -154,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): ...@@ -154,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int):
# <|image|> * num_images + 0 * (seq_len - num_images) # <|image|> * num_images + 0 * (seq_len - num_images)
assert seq_len >= num_images, \ assert seq_len >= num_images, \
"seq_len should be greater than or equal to num_images" "seq_len should be greater than or equal to num_images"
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_images return SequenceData.from_prompt_token_counts(
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) (MLLAMA_IMAGE_TOKEN_ID, num_images),
return SequenceData(token_ids) (0, seq_len - num_images),
)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int): def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
num_tokens = get_max_mllama_image_tokens(ctx) * num_images num_tokens = get_max_mllama_image_tokens(ctx) * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[MLLAMA_IMAGE_TOKEN_ID]) * num_tokens return SequenceData.from_prompt_token_counts(
return SequenceData(token_ids) (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
def dummy_image(num_images: int, ): def dummy_image(num_images: int, ):
......
...@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, ...@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -945,9 +946,9 @@ def pad_images( ...@@ -945,9 +946,9 @@ def pad_images(
return images, image_input_idx, image_masks return images, image_input_idx, image_masks
def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
prompt = llm_inputs.get("prompt", None) prompt = inputs.get("prompt", None)
multi_modal_data = llm_inputs.get("multi_modal_data", None) multi_modal_data = inputs.get("multi_modal_data", None)
if multi_modal_data is not None: if multi_modal_data is not None:
image = multi_modal_data.get("image", None) image = multi_modal_data.get("image", None)
else: else:
...@@ -965,9 +966,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -965,9 +966,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
elif prompt is not None: elif prompt is not None:
out = processor.process(prompt, image) out = processor.process(prompt, image)
else: else:
out = processor.process(None, out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
image,
tokens=llm_inputs["prompt_token_ids"])
image_processor = processor.image_processor image_processor = processor.image_processor
max_total_crops = 1 + image_processor.max_crops max_total_crops = 1 + image_processor.max_crops
...@@ -1020,9 +1019,9 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -1020,9 +1019,9 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = dict(image=image_data) multi_modal_data = dict(image=image_data)
return LLMInputs( return token_inputs(
prompt_token_ids=out["input_ids"], prompt_token_ids=out["input_ids"],
prompt=llm_inputs["prompt"], prompt=inputs["prompt"],
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
......
...@@ -7,7 +7,8 @@ from transformers import PaliGemmaConfig ...@@ -7,7 +7,8 @@ from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -68,7 +69,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, ...@@ -68,7 +69,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
return seq_data, mm_data return seq_data, mm_data
def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_paligemma(ctx: InputContext,
inputs: DecoderOnlyInputs):
""" """
The correct prompt format needs to be: The correct prompt format needs to be:
...@@ -77,9 +79,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -77,9 +79,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
""" # noqa """ # noqa
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_config(PaliGemmaConfig) hf_config = ctx.get_hf_config(PaliGemmaConfig)
...@@ -91,8 +93,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -91,8 +93,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
image_token_str_pad = image_token_str * image_feature_size image_token_str_pad = image_token_str * image_feature_size
image_token_ids_pad = [hf_config.image_token_index] * image_feature_size image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
orig_prompt = llm_inputs.get("prompt") orig_prompt = inputs.get("prompt")
orig_prompt_ids = llm_inputs.get("prompt_token_ids") orig_prompt_ids = inputs.get("prompt_token_ids")
if orig_prompt is not None and image_token_str in orig_prompt: if orig_prompt is not None and image_token_str in orig_prompt:
logger.warning( logger.warning(
...@@ -106,7 +108,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -106,7 +108,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
......
...@@ -27,7 +27,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig ...@@ -27,7 +27,8 @@ from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -410,12 +411,12 @@ def _get_image_placeholder_token_id_candidates( ...@@ -410,12 +411,12 @@ def _get_image_placeholder_token_id_candidates(
def input_processor_for_phi3v(ctx: InputContext, def input_processor_for_phi3v(ctx: InputContext,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
num_crops: Optional[int] = None): num_crops: Optional[int] = None):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
model_config = ctx.model_config model_config = ctx.model_config
hf_config = ctx.get_hf_image_processor_config() hf_config = ctx.get_hf_image_processor_config()
...@@ -447,7 +448,7 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -447,7 +448,7 @@ def input_processor_for_phi3v(ctx: InputContext,
else: else:
raise TypeError(f"Invalid image type: {type(image_data)}") raise TypeError(f"Invalid image type: {type(image_data)}")
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
if prompt is None: if prompt is None:
# for async server request, we assume prompt and its token_ids is always # for async server request, we assume prompt and its token_ids is always
# in correct format. And num_image_tags == len(image_data) always True. # in correct format. And num_image_tags == len(image_data) always True.
...@@ -464,7 +465,7 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -464,7 +465,7 @@ def input_processor_for_phi3v(ctx: InputContext,
image_data), "The count of image_placeholder not match image's" image_data), "The count of image_placeholder not match image's"
new_prompt = prompt new_prompt = prompt
prompt_token_ids = llm_inputs["prompt_token_ids"].copy() prompt_token_ids = inputs["prompt_token_ids"].copy()
print("prompt_token_ids (old)", prompt_token_ids) print("prompt_token_ids (old)", prompt_token_ids)
...@@ -506,10 +507,9 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -506,10 +507,9 @@ def input_processor_for_phi3v(ctx: InputContext,
new_token_ids.append(token_id) new_token_ids.append(token_id)
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
return llm_inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_image_input_mapper()
......
...@@ -14,7 +14,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask ...@@ -14,7 +14,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -62,7 +62,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -62,7 +62,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_feature_size = (size**2) // (patch_size**2) image_feature_size = (size**2) // (patch_size**2)
num_image_tokens = image_feature_size * num_images num_image_tokens = image_feature_size * num_images
seq_data = SequenceData.from_token_counts( seq_data = SequenceData.from_prompt_token_counts(
(image_token_id, num_image_tokens), (image_token_id, num_image_tokens),
(0, seq_len - num_image_tokens), (0, seq_len - num_image_tokens),
) )
...@@ -102,8 +102,8 @@ def input_mapper_for_pixtral(ctx: InputContext, ...@@ -102,8 +102,8 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images}) return MultiModalInputs({"images": images})
def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data: if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer, ctx.model_config.tokenizer,
...@@ -112,15 +112,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -112,15 +112,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img image_token_id = mm_encoder.special_ids.img
if image_token_id not in llm_inputs['prompt_token_ids']: if image_token_id not in inputs['prompt_token_ids']:
raise ValueError( raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}" (f"You've passed {inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's" " Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more" " tokenizer or pass a chat completion request. For more"
" For more info, see: " " For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")) "https://github.com/vllm-project/vllm/issues/8411."))
return llm_inputs return inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
......
...@@ -22,7 +22,8 @@ from transformers import PretrainedConfig ...@@ -22,7 +22,8 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -652,30 +653,30 @@ def get_image_text(image_num: int, padding: bool) -> str: ...@@ -652,30 +653,30 @@ def get_image_text(image_num: int, padding: bool) -> str:
def input_processor_for_qwen(ctx: InputContext, def input_processor_for_qwen(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs: inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
"""Processes the inputs, which may or may not be multimodal. """Processes the inputs, which may or may not be multimodal.
Multimodal inputs will only be processed if the model has a "visual" Multimodal inputs will only be processed if the model has a "visual"
component in its model config, otherwise they'll be ignored. component in its model config, otherwise they'll be ignored.
Args: Args:
ctx: Context of the loaded model. ctx: Context of the loaded model.
llm_inputs: LLM inputs which may have a multi_modal_data attribute. inputs: LLM inputs which may have a multi_modal_data attribute.
Returns: Returns:
If the model is language only or not multimodal inputs were provided, If the model is language only or not multimodal inputs were provided,
returns llm_inputs unmodified. Otherwise, processes the multimodal returns inputs unmodified. Otherwise, processes the multimodal
images / image embeddings and adds the fixed-length image placeholders. images / image embeddings and adds the fixed-length image placeholders.
""" """
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
# Only process images if we have multimodal data and a visual config # Only process images if we have multimodal data and a visual config
hf_config = ctx.get_hf_config() hf_config = ctx.get_hf_config()
if (multi_modal_data is None or "image" not in multi_modal_data if (multi_modal_data is None or "image" not in multi_modal_data
or not hasattr(hf_config, "visual")): or not hasattr(hf_config, "visual")):
return llm_inputs return inputs
prompt = llm_inputs.get("prompt") prompt = inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"] prompt_token_ids = inputs["prompt_token_ids"]
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
model_config.tokenizer, model_config.tokenizer,
...@@ -713,7 +714,7 @@ def input_processor_for_qwen(ctx: InputContext, ...@@ -713,7 +714,7 @@ def input_processor_for_qwen(ctx: InputContext,
new_prompt_token_ids = tokenizer.encode(new_prompt) new_prompt_token_ids = tokenizer.encode(new_prompt)
return LLMInputs(prompt=new_prompt, return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids, prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
...@@ -822,7 +823,7 @@ def dummy_data_for_qwen( ...@@ -822,7 +823,7 @@ def dummy_data_for_qwen(
# The presence of a visual config indicates this is a multimodal model. # The presence of a visual config indicates this is a multimodal model.
# If we don't have it, the model is considered an LLM for warmup purposes. # If we don't have it, the model is considered an LLM for warmup purposes.
if not hasattr(hf_config, "visual"): if not hasattr(hf_config, "visual"):
seq_data = SequenceData.from_token_counts((0, seq_len)) seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
mm_data = None mm_data = None
return seq_data, mm_data return seq_data, mm_data
......
...@@ -46,7 +46,8 @@ from vllm.attention.selector import (_Backend, backend_name_to_enum, ...@@ -46,7 +46,8 @@ from vllm.attention.selector import (_Backend, backend_name_to_enum,
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
...@@ -716,7 +717,7 @@ def dummy_data_for_qwen2_vl( ...@@ -716,7 +717,7 @@ def dummy_data_for_qwen2_vl(
hf_config = ctx.get_hf_config(Qwen2VLConfig) hf_config = ctx.get_hf_config(Qwen2VLConfig)
dummy_seqdata = SequenceData.from_token_counts( dummy_seqdata = SequenceData.from_prompt_token_counts(
(hf_config.vision_start_token_id, 1), (hf_config.vision_start_token_id, 1),
(hf_config.image_token_id, max_llm_image_tokens), (hf_config.image_token_id, max_llm_image_tokens),
(hf_config.vision_end_token_id, 1), (hf_config.vision_end_token_id, 1),
...@@ -799,11 +800,13 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, ...@@ -799,11 +800,13 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
return prompt_token_ids_with_data return prompt_token_ids_with_data
def input_processor_for_qwen2_vl(ctx: InputContext, def input_processor_for_qwen2_vl(
llm_inputs: LLMInputs) -> LLMInputs: ctx: InputContext,
multi_modal_data = llm_inputs.get("multi_modal_data", None) inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data", None)
if multi_modal_data is None: if multi_modal_data is None:
return llm_inputs return inputs
image_inputs = multi_modal_data.get("image", None) image_inputs = multi_modal_data.get("image", None)
video_inputs = multi_modal_data.get("video", None) video_inputs = multi_modal_data.get("video", None)
...@@ -817,7 +820,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -817,7 +820,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
# `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
# #
# The following code is equivalent to: # The following code is equivalent to:
# prompt = llm_inputs["prompt"] # prompt = inputs["prompt"]
# inputs = processor(text=[prompt], # inputs = processor(text=[prompt],
# images=image_inputs, # images=image_inputs,
# videos=video_inputs, # videos=video_inputs,
...@@ -825,9 +828,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -825,9 +828,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
# return_tensors="pt") # return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist() # prompt_token_ids = inputs["input_ids"][0].tolist()
prompt_token_ids = llm_inputs.get("prompt_token_ids", None) prompt_token_ids = inputs.get("prompt_token_ids", None)
if prompt_token_ids is None: if prompt_token_ids is None:
prompt = llm_inputs["prompt"] prompt = inputs["prompt"]
prompt_token_ids = processor.tokenizer( prompt_token_ids = processor.tokenizer(
prompt, prompt,
padding=True, padding=True,
...@@ -868,9 +871,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -868,9 +871,9 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
image_processor, image_processor,
prompt_token_ids) prompt_token_ids)
return LLMInputs( return token_inputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=llm_inputs["prompt"], prompt=inputs["prompt"],
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
) )
......
...@@ -13,7 +13,7 @@ from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention ...@@ -13,7 +13,7 @@ from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -67,7 +67,7 @@ def dummy_seq_data_for_siglip( ...@@ -67,7 +67,7 @@ def dummy_seq_data_for_siglip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
return SequenceData.from_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) )
...@@ -111,14 +111,14 @@ def dummy_video_for_siglip( ...@@ -111,14 +111,14 @@ def dummy_video_for_siglip(
def input_processor_for_siglip( def input_processor_for_siglip(
model_config: ModelConfig, model_config: ModelConfig,
hf_config: SiglipVisionConfig, hf_config: SiglipVisionConfig,
llm_inputs: LLMInputs, inputs: DecoderOnlyInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None, image_feature_size_override: Optional[Union[int, List[int]]] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
...@@ -135,14 +135,14 @@ def input_processor_for_siglip( ...@@ -135,14 +135,14 @@ def input_processor_for_siglip(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=image_token_id, placeholder_token_id=image_token_id,
repeat_count=image_feature_size, repeat_count=image_feature_size,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs( return token_inputs(
prompt_token_ids=new_token_ids, prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
......
...@@ -18,7 +18,7 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder ...@@ -18,7 +18,7 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs from vllm.inputs.data import DecoderOnlyInputs, token_inputs
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -156,10 +156,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): ...@@ -156,10 +156,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
return MultiModalInputs({"audio_features": audio_features}) return MultiModalInputs({"audio_features": audio_features})
def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data: if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs return inputs
feature_extractor = whisper_feature_extractor(ctx) feature_extractor = whisper_feature_extractor(ctx)
audios = multi_modal_data["audio"] audios = multi_modal_data["audio"]
...@@ -196,14 +196,14 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -196,14 +196,14 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
llm_inputs.get("prompt"), inputs.get("prompt"),
llm_inputs["prompt_token_ids"], inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_token_counts, repeat_count=audio_token_counts,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data)
......
...@@ -13,8 +13,7 @@ from typing import Set, Tuple, Union, cast ...@@ -13,8 +13,7 @@ from typing import Set, Tuple, Union, cast
import msgspec import msgspec
import torch import torch
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -22,6 +21,7 @@ from vllm.sampling_params import SamplingParams ...@@ -22,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.inputs import SingletonInputs
from vllm.multimodal.base import MultiModalDataDict from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_TOKEN_ID_ARRAY_TYPE = "l"
...@@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" ...@@ -29,6 +29,11 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1 VLLM_INVALID_TOKEN_ID = -1
def array_full(token_id: int, count: int):
""":class:`array` equivalent of :func:`numpy.full`."""
return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
# We use dataclass for now because it is used for # We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable. # openai server output, and msgspec is not serializable.
# TODO(sang): Fix it. # TODO(sang): Fix it.
...@@ -173,22 +178,34 @@ class SequenceData(msgspec.Struct, ...@@ -173,22 +178,34 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
@staticmethod @staticmethod
def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
Each tuple represents one token sequence, expressed in the form
:code:`(token_id, count)`.
"""
if len(token_counts) == 0: if len(token_counts) == 0:
return SequenceData.from_seqs([]) return SequenceData.from_seqs([])
arrs = [ prompt_token_ids_arr = reduce(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count array.__iadd__,
for token_id, count in token_counts (array_full(token_id, count) for token_id, count in token_counts),
] )
return SequenceData(reduce(array.__add__, arrs)) return SequenceData(prompt_token_ids_arr)
@staticmethod @staticmethod
def from_seqs( def from_seqs(
prompt_token_ids: GenericSequence[int], prompt_token_ids: GenericSequence[int],
output_token_ids: Optional[GenericSequence[int]] = None, output_token_ids: Optional[GenericSequence[int]] = None,
) -> "SequenceData": ) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance from prompt and output
token sequences.
"""
prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
prompt_token_ids) prompt_token_ids)
...@@ -362,14 +379,14 @@ class SequenceData(msgspec.Struct, ...@@ -362,14 +379,14 @@ class SequenceData(msgspec.Struct,
class Sequence: class Sequence:
"""Stores the data, status, and block information of a sequence. """Stores the data, status, and block information of a sequence.
The sequence is constructed from the LLMInputs instance passed The sequence is constructed from the :code:`SingletonInputs` instance
in through the `inputs` constructor argument. passed in through the :code:`inputs` constructor argument.
For encoder/decoder models, LLMInputs encapsulates both a For encoder/decoder models, SingletonInputs encapsulates both a
decoder and encoder prompt, creating an ambiguity about which decoder and encoder prompt, creating an ambiguity about which
prompt to construct the sequence from. The `from_decoder_prompt` prompt to construct the sequence from. The `from_decoder_prompt`
constructor argument signals whether to construct the Sequence constructor argument signals whether to construct the Sequence
from the LLMInputs decoder prompt, or encoder prompt. from the SingletonInputs decoder prompt, or encoder prompt.
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
...@@ -379,16 +396,16 @@ class Sequence: ...@@ -379,16 +396,16 @@ class Sequence:
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
lora_request: LoRA request. lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request. prompt_adapter_request: Prompt Adapter request.
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt from_decoder_prompt: Construct Sequence from SingletonInputs decoder
(True) or encoder prompt (False.) Must be True prompt (True) or encoder prompt (False.) Must be
for decoder-only model. True for decoder-only model.
""" """
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
inputs: "LLMInputs", inputs: "SingletonInputs",
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
...@@ -404,19 +421,19 @@ class Sequence: ...@@ -404,19 +421,19 @@ class Sequence:
self.from_decoder_prompt = from_decoder_prompt self.from_decoder_prompt = from_decoder_prompt
# For decoder-only models, a Sequence is constructed # For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.) # from an DecoderOnlyInputs instance (the `inputs` arg.)
# #
# For encoder/decoder models the same `inputs` # For encoder/decoder models the same `inputs`
# instance could be utilized to construct either an # instance could be utilized to construct either an
# encoder sequence or a decoder sequence, because # encoder sequence or a decoder sequence, because
# `LLMInputs` has both decoder- and encoder-oriented # `DecoderOnlyInputs` has both decoder- and encoder-oriented
# member variables (i.e. it encapsulates both an encoder # member variables (i.e. it encapsulates both an encoder
# and a decoder prompt.) The decision of which type of sequence # and a decoder prompt.) The decision of which type of sequence
# to generate is determined by the `from_decoder_prompt` argument. # to generate is determined by the `from_decoder_prompt` argument.
# #
# When constructing a encoder sequence # When constructing a encoder sequence
# (`from_decoder_prompt` False) it matters that # (`from_decoder_prompt` False) it matters that
# the `LLMInputs` instance stored in `inputs` is valid # the `DecoderOnlyInputs` instance stored in `inputs` is valid
# in the sense that its encoder-related member variables are # in the sense that its encoder-related member variables are
# populated; below, an exception is raised if this is # populated; below, an exception is raised if this is
# not the case. # not the case.
...@@ -424,8 +441,7 @@ class Sequence: ...@@ -424,8 +441,7 @@ class Sequence:
# When constructing a decoder sequence (`from_decoder_prompt` True) # When constructing a decoder sequence (`from_decoder_prompt` True)
# it does not matter whether `inputs` has its encoder-related # it does not matter whether `inputs` has its encoder-related
# member variables populated. # member variables populated.
if not (from_decoder_prompt if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
or is_valid_encoder_decoder_llm_inputs(inputs)):
raise ValueError("Cannot extract encoder input prompt from " raise ValueError("Cannot extract encoder input prompt from "
f"invalid input {inputs}; did you forget the " f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?") "encoder input prompt fields?")
...@@ -471,15 +487,19 @@ class Sequence: ...@@ -471,15 +487,19 @@ class Sequence:
@property @property
def multi_modal_data(self) -> "MultiModalDataDict": def multi_modal_data(self) -> "MultiModalDataDict":
if self.inputs.get("multi_modal_data") and self.inputs.get( inputs = self.inputs
"encoder_multi_modal_data"):
if (inputs.get("multi_modal_data")
and inputs.get("encoder_multi_modal_data")):
raise ValueError( raise ValueError(
"Multi-modal data in both encoder and decoder is not supported." "Multi-modal data in both encoder and decoder is not supported."
) )
inputs = self.inputs
return self.inputs.get("multi_modal_data") or (cast( return cast(
EncoderDecoderLLMInputs, "MultiModalDataDict",
inputs).get("encoder_multi_modal_data")) or {} (inputs.get("multi_modal_data")
or inputs.get("encoder_multi_modal_data") or {}),
)
@property @property
def mm_processor_kwargs(self) -> Dict[str, Any]: def mm_processor_kwargs(self) -> Dict[str, Any]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment