Unverified Commit f2ebb6f5 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[V1] Scatter and gather placeholders in the model runner (#16076)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarJennifer Zhao <ai.jenniferzhao@gmail.com>
parent 1d012112
...@@ -33,7 +33,6 @@ from vllm.attention.layer import MultiHeadAttention ...@@ -33,7 +33,6 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -50,7 +49,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -50,7 +49,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -58,9 +57,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP ...@@ -58,9 +57,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
class Llama4ImagePatchInputs(TypedDict): class Llama4ImagePatchInputs(TypedDict):
...@@ -77,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict): ...@@ -77,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
""" """
A list of aspect ratios corresponding to the number of tiles A list of aspect ratios corresponding to the number of tiles
...@@ -510,11 +502,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -510,11 +502,10 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
# image_start + local tiles * (patches + 1 x separator) + patch_per_chunk = self.get_patch_per_chunk(vision_config)
# 1 global tile * (image x 1 + patches) + image_end num_patches = self.get_max_num_tiles() + 1
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2 return {"image": patch_per_chunk * num_patches}
return {"image": mm_max_tokens}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
...@@ -523,6 +514,14 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -523,6 +514,14 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
return ImageSize(height=self.get_max_num_tiles() * image_size, return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size) width=image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
): ):
...@@ -578,33 +577,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -578,33 +577,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for (r_h, r_w) in aspect_ratios for (r_h, r_w) in aspect_ratios
] ]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
num_patches_per_chunk = self.info.get_patch_per_chunk(
vision_config)
expanded_image_tokens_list = [
processor._prompt_split_image(aspect_ratio,
num_patches_per_chunk)
for aspect_ratio in aspect_ratios
]
expanded_image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in expanded_image_tokens_list
]
embed_is_patch = [
torch.tensor(tokens) == patch_id
for tokens in expanded_image_token_ids
]
processed_outputs["aspect_ratios"] = aspect_ratios processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["patches_per_image"] = torch.tensor( processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image) patches_per_image)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
...@@ -619,7 +594,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -619,7 +594,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
"image", patches_per_image), "image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"), patches_per_image=MultiModalFieldConfig.batched("image"),
aspect_ratios=MultiModalFieldConfig.batched("image"), aspect_ratios=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -639,12 +613,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -639,12 +613,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config) num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
img_patch_token = hf_processor.img_patch_token
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
return hf_processor._prompt_split_image(
repl = hf_processor._prompt_split_image(
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
num_patches_per_chunk=num_patches_per_chunk) num_patches_per_chunk=num_patches_per_chunk,
)
return PromptUpdateDetails.select_text(repl, img_patch_token)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -737,11 +716,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -737,11 +716,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
flat_pixel_values = flatten_bn(pixel_values, concat=True) flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
aspect_ratios = kwargs.pop("aspect_ratios", None) aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)): if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. " raise ValueError("Incorrect type of aspect_ratios. "
...@@ -751,7 +725,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -751,7 +725,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
type="pixel_values", type="pixel_values",
flat_data=flat_pixel_values, flat_data=flat_pixel_values,
patches_per_image=patches_per_image, patches_per_image=patches_per_image,
embed_is_patch=embed_is_patch,
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
) )
...@@ -759,10 +732,15 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -759,10 +732,15 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
flat_data = image_input["flat_data"] flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist() patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data) vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector( vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat) vision_embeddings_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
return [
img.flatten(0, 1)
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]: **kwargs) -> Optional[MultiModalEmbeddings]:
...@@ -770,20 +748,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -770,20 +748,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None: if image_input is None:
return None return None
# num_images x [num_chunks, num_patches, hidden_dim] return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat = [img.flatten(0, 1) for img in image_features]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat = [
is_patch.flatten(0, 1)
for is_patch in image_input["embed_is_patch"]
]
return scatter_patch_features(
image_features_flat,
embed_is_patch_flat,
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -794,9 +759,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -794,9 +759,11 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, input_ids,
select_patch_features(multimodal_embeddings), inputs_embeds,
self.config.image_token_index) multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds return inputs_embeds
......
...@@ -46,7 +46,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -46,7 +46,8 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -56,7 +57,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -56,7 +57,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
...@@ -84,14 +84,6 @@ class MolmoImageInputs(TypedDict): ...@@ -84,14 +84,6 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)` Shape: `(batch_size * num_images, num_crops, num_patch)`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: torch.Tensor num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
...@@ -1146,30 +1138,6 @@ class MolmoProcessorWrapper: ...@@ -1146,30 +1138,6 @@ class MolmoProcessorWrapper:
if image_input_idx is not None: if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0 feat_is_patch = image_input_idx >= 0
input_is_embed = torch.isin(
input_ids,
torch.tensor([
self.image_patch_id,
self.im_col_id,
self.im_start_id,
self.im_end_id,
]),
)
embed_ids = input_ids[input_is_embed]
embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
assert len(embed_start) == len(embed_end) == len(images)
embed_is_patch = [
embed_is_patch[start:end + 1]
for start, end in zip(embed_start, embed_end)
]
tilings = [ tilings = [
self.select_tiling( self.select_tiling(
image_width=image.size[0], image_width=image.size[0],
...@@ -1181,7 +1149,6 @@ class MolmoProcessorWrapper: ...@@ -1181,7 +1149,6 @@ class MolmoProcessorWrapper:
assert num_crops.sum() == len(feat_is_patch) assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch outputs["feat_is_patch"] = feat_is_patch
outputs["embed_is_patch"] = embed_is_patch
outputs["num_crops"] = num_crops outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id outputs["img_patch_id"] = self.image_patch_id
...@@ -1220,17 +1187,13 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1220,17 +1187,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
) )
pooling_size = processor.pooling_size pooling_size = processor.pooling_size
base_image_input_size = processor.base_image_input_size image_token_length_w = processor.image_token_length_w
base_image_input_d = processor.image_patch_size image_token_length_h = processor.image_token_length_h
crop_patches = base_image_input_size[0] // base_image_input_d
per_row = ncols // pooling_size + 1 extra = image_token_length_w * image_token_length_h
joint = per_row * (nrows // pooling_size) + 2 joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
resize = (image_token_length + 1) * image_token_length + 2
return resize + joint return extra + joint
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
...@@ -1328,7 +1291,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1328,7 +1291,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops), "image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes( feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops), "image", num_crops),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images), img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -1368,8 +1330,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1368,8 +1330,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint = ([img_start_id] + joint_row * joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id]) ((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint return PromptUpdateDetails.select_token_id(
return image_tokens extra_joint + joint,
embed_token_id=img_patch_id,
)
return [ return [
PromptInsertion( PromptInsertion(
...@@ -1475,11 +1439,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1475,11 +1439,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise ValueError("Incorrect type of feat_is_patch. " raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}") f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
if not isinstance(num_crops, (torch.Tensor, list)): if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_crops. "
...@@ -1491,14 +1450,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1491,14 +1450,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}") f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
image_masks=image_masks, image_masks=image_masks,
feat_is_patch=feat_is_patch, feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops, num_crops=num_crops,
) )
...@@ -1537,12 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1537,12 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -1556,7 +1508,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1556,7 +1508,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_patch_id, self.img_patch_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor): ...@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find "<tile" as a subsequence of "<Image><tile" # when trying to find "<tile" as a subsequence of "<Image><tile"
repl = "<Image>" + features + "</Image>" repl = "<Image>" + features + "</Image>"
return PromptUpdateDetails(full=repl, features=repl) return PromptUpdateDetails.select_text(repl, IMG_PAD)
class NVLMProcessingInfo(BaseInternVLProcessingInfo): class NVLMProcessingInfo(BaseInternVLProcessingInfo):
...@@ -84,31 +84,6 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo): ...@@ -84,31 +84,6 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**kwargs, **kwargs,
) )
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
tokenizer = hf_processor.tokenizer
max_num_patches = hf_processor.max_dynamic_patch
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers = [
f"<tile_{i+1}>" for i in range(max_num_patches)
]
if hf_processor.use_thumbnail and max_num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
# so we include <tile_1> in the start_str
start_str = "<Image>" + tile_pos_identifiers.pop(0)
end_str = "</Image>"
start_token_len = len(tokenizer.encode(start_str))
end_token_len = len(tokenizer.encode(end_str))
tile_token_len = sum(
len(tokenizer.encode(identifier))
for identifier in tile_pos_identifiers)
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
return super().get_max_image_tokens() + non_image_tokens_num
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
...@@ -177,10 +152,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): ...@@ -177,10 +152,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl = hf_processor.get_image_repl(feature_size, num_patches) repl = hf_processor.get_image_repl(feature_size, num_patches)
return PromptUpdateDetails( return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
full=repl.full + "\n",
features=repl.features + "\n",
)
# See note in dummy data regarding why we have the extra newline # See note in dummy data regarding why we have the extra newline
return [ return [
......
...@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor( ...@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
modality="image", modality="image",
target=PromptIndexTargets.prefix( target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []), [bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails( insertion=PromptUpdateDetails.select_token_id(
full=image_tokens + [bos_token_id], image_tokens + [bos_token_id],
features=image_tokens, embed_token_id=image_token_id,
), ),
) )
] ]
......
...@@ -40,8 +40,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -40,8 +40,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate, BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate)
PromptUpdateDetails)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -443,12 +442,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -443,12 +442,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
processor=hf_processor, processor=hf_processor,
) )
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens return [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptUpdateDetails(
full=image_tokens,
features=image_tokens,
)
num_images = mm_items.get_count("image", strict=False) num_images = mm_items.get_count("image", strict=False)
...@@ -517,6 +511,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -517,6 +511,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
item_idx=p.item_idx, item_idx=p.item_idx,
start_idx=p.start_idx - 1, start_idx=p.start_idx - 1,
tokens=p.tokens, tokens=p.tokens,
is_embed=p.is_embed,
) for p in ps ) for p in ps
] ]
for modality, ps in placeholders.items() for modality, ps in placeholders.items()
......
...@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (MistralTokenizer,
...@@ -46,8 +46,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer, ...@@ -46,8 +46,7 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs, from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
scatter_patch_features, select_patch_features)
try: try:
from xformers import ops as xops from xformers import ops as xops
...@@ -68,14 +67,6 @@ class PixtralImagePixelInputs(TypedDict): ...@@ -68,14 +67,6 @@ class PixtralImagePixelInputs(TypedDict):
The result of stacking :attr:`ImageEncoding.tokens` from each prompt. The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class PixtralProcessorAdapter: class PixtralProcessorAdapter:
""" """
...@@ -144,11 +135,8 @@ class PixtralProcessorAdapter: ...@@ -144,11 +135,8 @@ class PixtralProcessorAdapter:
"For more info, see: " "For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.") "https://github.com/vllm-project/vllm/issues/8411.")
image_token_id = self.image_token_id
images_processed = list[torch.Tensor]() images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]() images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
for image in images: for image in images:
image_inputs = self.image_processor(ImageChunk(image=image)) image_inputs = self.image_processor(ImageChunk(image=image))
...@@ -157,12 +145,10 @@ class PixtralProcessorAdapter: ...@@ -157,12 +145,10 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed) images_processed.append(image_processed)
images_tokens.append(image_tokens) images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
return { return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed, "images": images_processed,
"embed_is_patch": images_embed_is_patch,
} }
...@@ -213,7 +199,7 @@ class PixtralProcessingInfo(BaseProcessingInfo): ...@@ -213,7 +199,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
ncols, nrows = processor.image_processor._image_to_num_tokens( ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height))) Image.new("RGB", (image_width, image_height)))
return (ncols + 1) * nrows return ncols * nrows
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
...@@ -263,10 +249,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -263,10 +249,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
hf_inputs: Mapping[str, NestedTensors], hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(images=MultiModalFieldConfig.batched("image"))
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -290,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -290,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -381,17 +364,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -381,17 +364,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of images. " raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}") f"Got type: {type(images)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralImagePixelInputs( return PixtralImagePixelInputs(
type="pixel_values", type="pixel_values",
images=flatten_bn(images), images=flatten_bn(images),
embed_is_patch=embed_is_patch,
) )
def _process_image_input( def _process_image_input(
...@@ -427,12 +402,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -427,12 +402,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -444,7 +414,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -444,7 +414,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.vision_args.image_token_id, self.vision_args.image_token_id,
) )
return inputs_embeds return inputs_embeds
...@@ -963,9 +933,7 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ...@@ -963,9 +933,7 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
) )
return ncols * nrows
# Consider the image_break_token
return (ncols + 1) * nrows
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_size = self.get_image_size() image_size = self.get_image_size()
......
...@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor( ...@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features audio_tokens = [audio_token_id] * num_features
return PromptUpdateDetails( return PromptUpdateDetails.select_token_id(
full=[audio_bos_id] + audio_tokens + [audio_eos_id], [audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens, embed_token_id=audio_token_id,
) )
return [ return [
......
...@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[img_start_id, img_end_id], target=[img_start_id, img_end_id],
replacement=PromptUpdateDetails( replacement=PromptUpdateDetails.select_token_id(
full=[img_start_id] + image_tokens + [img_end_id], [img_start_id] + image_tokens + [img_end_id],
features=image_tokens, embed_token_id=img_pad_id,
), ),
) )
] ]
......
...@@ -40,7 +40,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -40,7 +40,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -61,14 +60,6 @@ class SkyworkR1VImagePixelInputs(TypedDict): ...@@ -61,14 +60,6 @@ class SkyworkR1VImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class SkyworkR1VImageEmbeddingInputs(TypedDict): class SkyworkR1VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -419,24 +410,13 @@ class BaseSkyworkR1VProcessor(ABC): ...@@ -419,24 +410,13 @@ class BaseSkyworkR1VProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]), torch.tensor([len(item) for item in pixel_values_lst]),
} }
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0] num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches) image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text] text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
...@@ -460,7 +440,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor): ...@@ -460,7 +440,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
...@@ -599,7 +579,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -599,7 +579,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -835,7 +814,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -835,7 +814,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
...@@ -864,20 +842,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -864,20 +842,14 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. " raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}") f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return SkyworkR1VImagePixelInputs( return SkyworkR1VImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=self._validate_pixel_values( pixel_values_flat=self._validate_pixel_values(
pixel_values_flat), pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -923,15 +895,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -923,15 +895,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -945,7 +909,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -945,7 +909,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_context_token_id, self.img_context_token_id,
) )
return inputs_embeds return inputs_embeds
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from typing import Final, Generic, Optional, Protocol, TypeVar, Union
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -10,12 +9,9 @@ from transformers import PretrainedConfig ...@@ -10,12 +9,9 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum, from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from .interfaces import MultiModalEmbeddings
logger = init_logger(__name__) logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig) _C = TypeVar("_C", bound=PretrainedConfig)
...@@ -155,74 +151,3 @@ def resolve_visual_encoder_outputs( ...@@ -155,74 +151,3 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer: if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs) hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1) return torch.cat(hs_pool, dim=-1)
def scatter_patch_features(
patches: Union[torch.Tensor, Sequence[torch.Tensor]],
embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func`select_patch_features`.
Args:
patches: The patch features for each image.
Shape: `(num_images, <patch_dims>, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one image:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 p3 p4 ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
if len(patches) != len(embed_is_patch):
raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
f"{len(embed_is_patch)=}")
def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
embed_one = patches_one.new_full(
(e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
)
embed_one[e_is_patch] = patches_one
return embed_one
return tuple(
get_embed_one(patches_one, e_is_patch)
for patches_one, e_is_patch in zip(patches, embed_is_patch))
def select_patch_features(
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
return cast(MultiModalEmbeddings, selected_features)
...@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap: ...@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
for placeholder_dict, mm_item in zip(multi_modal_placeholders, for placeholder_dict, mm_item in zip(multi_modal_placeholders,
multi_modal_items): multi_modal_items):
placeholder = range( placeholder = range(
placeholder_dict["offset"], placeholder_dict.offset,
placeholder_dict["offset"] + placeholder_dict["length"], placeholder_dict.offset + placeholder_dict.length,
) )
intersection = range( intersection = range(
max(positions.start, placeholder.start), max(positions.start, placeholder.start),
......
...@@ -109,7 +109,8 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`. ...@@ -109,7 +109,8 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
""" """
class PlaceholderRange(TypedDict): @dataclass(frozen=True)
class PlaceholderRange:
""" """
Placeholder location information for multi-modal data. Placeholder location information for multi-modal data.
...@@ -121,8 +122,8 @@ class PlaceholderRange(TypedDict): ...@@ -121,8 +122,8 @@ class PlaceholderRange(TypedDict):
.. code-block:: .. code-block::
A: { "offset": 0, "length": 4 } A: PlaceholderRange(offset=0, length=4)
B: { "offset": 5, "length": 4 } B: PlaceholderRange(offset=5, length=4)
""" """
offset: int offset: int
...@@ -131,6 +132,31 @@ class PlaceholderRange(TypedDict): ...@@ -131,6 +132,31 @@ class PlaceholderRange(TypedDict):
length: int length: int
"""The length of the placeholder.""" """The length of the placeholder."""
is_embed: Optional[torch.Tensor] = None
"""
A boolean mask of shape `(length,)` indicating which positions
between `offset` and `offset + length` to assign embeddings to.
"""
def get_num_embeds(self) -> int:
if self.is_embed is None:
return self.length
return int(self.is_embed.sum().item())
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
if not (self.offset, self.length) == (other.offset, other.length):
return False
if self.is_embed is None:
return other.is_embed is None
if other.is_embed is None:
return self.is_embed is None
return nested_tensors_equal(self.is_embed, other.is_embed)
NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor, NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
tuple[torch.Tensor, ...]] tuple[torch.Tensor, ...]]
......
...@@ -108,16 +108,46 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -108,16 +108,46 @@ class PromptUpdateDetails(Generic[_S]):
full: _S full: _S
"""The full content.""" """The full content."""
features: _S is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
""" """
The part of the content that corresponds to feature placeholders; Given :attr:`full`, return a boolean mask of shape `(len(full),)`
this will be replaced by the output of the vision encoder during model indicating which positions of `full` to assign embeddings to.
inference.
`None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling
:class:`SupportsMultiModal.get_multimodal_embeddings`.
""" """
@staticmethod @staticmethod
def from_seq(seq: _S) -> "PromptUpdateDetails[_S]": def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
return PromptUpdateDetails(full=seq, features=seq) return PromptUpdateDetails(full=seq)
@staticmethod
def select_text(
seq: _S,
embed_text: str,
) -> "PromptUpdateDetails[_S]":
def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
embed_token_ids = encode_tokens(full.tokenizer, embed_text)
return torch.isin(
torch.tensor(full.token_ids),
torch.tensor(embed_token_ids),
)
return PromptUpdateDetails(full=seq, is_embed=is_embed)
@staticmethod
def select_token_id(
seq: _S,
embed_token_id: int,
) -> "PromptUpdateDetails[_S]":
return PromptUpdateDetails(
full=seq,
is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
)
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
...@@ -406,7 +436,7 @@ class _BoundPromptSequence: ...@@ -406,7 +436,7 @@ class _BoundPromptSequence:
@dataclass @dataclass
class _BoundPromptContent: class _BoundPromptContent:
full: _BoundPromptSequence full: _BoundPromptSequence
features: _BoundPromptSequence is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
@dataclass @dataclass
...@@ -466,10 +496,8 @@ class BoundPromptUpdate: ...@@ -466,10 +496,8 @@ class BoundPromptUpdate:
bound_full = _BoundPromptSequence.from_seq(self.tokenizer, bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
content.full) content.full)
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
content.features)
bound_content = _BoundPromptContent(full=bound_full, bound_content = _BoundPromptContent(full=bound_full,
features=bound_features) is_embed=content.is_embed)
if cache_key is not None: if cache_key is not None:
self._content_cache[cache_key] = bound_content self._content_cache[cache_key] = bound_content
...@@ -605,15 +633,19 @@ class PlaceholderFeaturesInfo: ...@@ -605,15 +633,19 @@ class PlaceholderFeaturesInfo:
item_idx: int item_idx: int
start_idx: int start_idx: int
tokens: list[int] tokens: list[int]
is_embed: Optional[torch.Tensor]
@property @property
def length(self) -> int: def length(self) -> int:
return len(self.tokens) return len(self.tokens)
def to_range(self) -> PlaceholderRange: def to_range(self) -> PlaceholderRange:
# TODO: Is it worth it to optimize this by stripping the
# leading and ending positions where `is_embed=False`?
return PlaceholderRange( return PlaceholderRange(
offset=self.start_idx, offset=self.start_idx,
length=self.length, length=self.length,
is_embed=self.is_embed,
) )
...@@ -806,22 +838,17 @@ def _iter_placeholders( ...@@ -806,22 +838,17 @@ def _iter_placeholders(
continue continue
if prompt[start_idx:end_idx_full] == content_tokens_full: if prompt[start_idx:end_idx_full] == content_tokens_full:
content_tokens_feat = content.features.token_ids content_is_embed = content.is_embed
if content_is_embed is not None:
content_is_embed = content_is_embed(content.full)
try:
match = next(
iter_token_matches(content_tokens_full,
content_tokens_feat))
yield PlaceholderFeaturesInfo( yield PlaceholderFeaturesInfo(
modality=modality, modality=modality,
item_idx=item_idx, item_idx=item_idx,
start_idx=start_idx + match.start_idx, start_idx=start_idx,
tokens=content_tokens_feat, tokens=content_tokens_full,
is_embed=content_is_embed,
) )
except StopIteration:
raise AssertionError(
f"{content_tokens_feat=} should be a "
f"subsequence of {content_tokens_full=}") from None
# Exclude overlapping matches # Exclude overlapping matches
start_idx = end_idx_full start_idx = end_idx_full
......
...@@ -181,7 +181,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -181,7 +181,7 @@ class MultiModalProfiler(Generic[_I]):
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = { total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders) modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in placeholders_by_modality.items() for modality, placeholders in placeholders_by_modality.items()
} }
expected_placeholders_by_modality = { expected_placeholders_by_modality = {
......
...@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata( ...@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
all_items.append((modality, placeholder, hash_value)) all_items.append((modality, placeholder, hash_value))
# Sort all items by offset # Sort all items by offset
all_items.sort(key=lambda x: x[1]['offset']) all_items.sort(key=lambda x: x[1].offset)
# Split into separate lists # Split into separate lists
sorted_modalities = [item[0] for item in all_items] sorted_modalities = [item[0] for item in all_items]
......
...@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, ...@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset. # Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of # We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase. # range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][ if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
"length"] < start_token_idx:
return extra_keys, start_mm_idx return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input. # Support start_mm_idx == -1 to indicate the last mm input.
...@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, ...@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx = start_mm_idx curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions): while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"] offset = mm_positions[curr_mm_idx].offset
length = mm_positions[curr_mm_idx]["length"] length = mm_positions[curr_mm_idx].length
if end_token_idx > offset: if end_token_idx > offset:
if start_token_idx > offset + length: if start_token_idx > offset + length:
# This block has passed the current mm input. # This block has passed the current mm input.
......
...@@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface): ...@@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
assert mm_positions is not None assert mm_positions is not None
assert len(mm_positions) > 0 assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions): for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"] start_pos = pos_info.offset
num_encoder_tokens = pos_info["length"] num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...@@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface): ...@@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
if cached_encoder_input_ids: if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids): for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id] mm_positions = request.mm_positions[input_id]
start_pos = mm_positions["offset"] start_pos = mm_positions.offset
num_tokens = mm_positions["length"] num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens: if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored # The encoder output is already processed and stored
# in the decoder's KV cache. # in the decoder's KV cache.
......
...@@ -121,7 +121,7 @@ class Request: ...@@ -121,7 +121,7 @@ class Request:
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions) assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id]["length"] num_tokens = self.mm_positions[input_id].length
return num_tokens return num_tokens
@property @property
......
...@@ -19,7 +19,8 @@ from vllm.logger import init_logger ...@@ -19,7 +19,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache ...@@ -43,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:
import xgrammar as xgr import xgrammar as xgr
...@@ -830,19 +832,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -830,19 +832,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return metadata return metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = [] mm_inputs = list[MultiModalKwargs]()
req_input_ids: list[tuple[str, int]] = [] req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id]) for mm_input_id in encoder_input_ids:
req_input_ids.append((req_id, input_id)) mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
...@@ -878,16 +882,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -878,16 +882,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs.append(output) encoder_outputs.append(output)
# Cache the encoder outputs. # Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache: if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {} self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs( self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = [] mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id] req_id]
...@@ -895,8 +906,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -895,8 +906,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions): for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"] start_pos = pos_info.offset
num_encoder_tokens = pos_info["length"] num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, # [num_computed_tokens,
...@@ -918,8 +929,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -918,8 +929,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert req_id in self.encoder_cache assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id] assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i] encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
...@@ -984,10 +1003,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -984,10 +1003,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
else: else:
encoder_outputs = [] mm_embeds = []
# Prepare the decoder inputs. # Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = ( attn_metadata, logits_indices, spec_decode_metadata = (
...@@ -1009,9 +1028,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1009,9 +1028,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens] input_ids = self.input_ids[:num_scheduled_tokens]
if encoder_outputs: if mm_embeds:
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
input_ids, encoder_outputs) input_ids, mm_embeds)
else: else:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize. # TODO(woosuk): Avoid the copy. Optimize.
......
...@@ -19,7 +19,8 @@ from vllm.config import VllmConfig ...@@ -19,7 +19,8 @@ from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -36,7 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler ...@@ -36,7 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -507,19 +509,46 @@ class TPUModelRunner: ...@@ -507,19 +509,46 @@ class TPUModelRunner:
logits_indices = logits_indices.to(self.device) logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices return attn_metadata, logits_indices
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): def _scatter_placeholders(
self,
embeds: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return embeds
placeholders = embeds.new_full(
(is_embed.shape[0], embeds.shape[-1]),
fill_value=torch.nan,
)
placeholders[is_embed] = embeds
return placeholders
def _gather_placeholders(
self,
placeholders: torch.Tensor,
is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
if is_embed is None:
return placeholders
return placeholders[is_embed]
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: if not scheduled_encoder_inputs:
return return
# Batch the multi-modal inputs. # Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = [] mm_inputs = list[MultiModalKwargs]()
req_input_ids: list[tuple[str, int]] = [] req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id] req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id]) for mm_input_id in encoder_input_ids:
req_input_ids.append((req_id, input_id)) mm_inputs.append(req_state.mm_inputs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has # Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one, # multiple modalities or a different modality than the previous one,
...@@ -555,16 +584,23 @@ class TPUModelRunner: ...@@ -555,16 +584,23 @@ class TPUModelRunner:
encoder_outputs.append(output) encoder_outputs.append(output)
# Cache the encoder outputs. # Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache: if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {} self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output
def _gather_encoder_outputs( self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
def _gather_mm_embeddings(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = [] mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids: for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id] req_id]
...@@ -572,8 +608,8 @@ class TPUModelRunner: ...@@ -572,8 +608,8 @@ class TPUModelRunner:
num_computed_tokens = req_state.num_computed_tokens num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions): for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"] start_pos = pos_info.offset
num_encoder_tokens = pos_info["length"] num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, # [num_computed_tokens,
...@@ -595,8 +631,16 @@ class TPUModelRunner: ...@@ -595,8 +631,16 @@ class TPUModelRunner:
assert req_id in self.encoder_cache assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id] assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i] encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
@torch.no_grad() @torch.no_grad()
def execute_model( def execute_model(
...@@ -612,10 +656,10 @@ class TPUModelRunner: ...@@ -612,10 +656,10 @@ class TPUModelRunner:
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
self._execute_encoder(scheduler_output) self._execute_mm_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output)
else: else:
encoder_outputs = [] mm_embeds = []
# Prepare inputs # Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
...@@ -623,9 +667,9 @@ class TPUModelRunner: ...@@ -623,9 +667,9 @@ class TPUModelRunner:
# NOTE(woosuk): To unify token ids and soft tokens (vision # NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids) # embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text. # as input to the multimodal model, even when the input is text.
if encoder_outputs: if mm_embeds:
inputs_embeds = self.model.get_input_embeddings( inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs) self.input_ids, mm_embeds)
else: else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids) inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None input_ids = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment