"examples/vscode:/vscode.git/clone" did not exist on "0f2fa9282858d7cc422a0f1bdd08684e5e703d6a"
Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
...@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize from vllm.multimodal.parse import ImageProcessorItems, ImageSize
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalDataItems, MultiModalDataItems, PromptReplacement,
MultiModalFieldConfig, PromptUpdate, PromptUpdateDetails)
PromptReplacement, PromptUpdate,
encode_tokens)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
# yapf: disable # yapf: disable
...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal ...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImagePixelInputs(TypedDict):
...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict): ...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict): class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict): ...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resize_output_size(self, def _resize_output_size(self,
*, *,
height: int, height: int,
...@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return grid_w * grid_h + 1 return grid_w * grid_h + 1
def _get_image_token(
self,
processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_image_token = processor.global_image_tag
return image_token, fake_image_token, global_image_token
def get_image_repl( def get_image_repl(
self, self,
*, *,
...@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_token = processor.image_token.content image_token, fake_image_token, global_img_token = self._get_image_token(
fake_image_token = processor.fake_image_token.content processor)
global_img_token = processor.global_image_tag
image_seq_len = processor.image_seq_len image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>" grid_placeholder = "<row_{n_h}_col_{n_w}>"
...@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
processor: Optional[Idefics3Processor], processor: Optional[Idefics3Processor],
) -> int: ) -> int:
tokenizer = self.get_tokenizer() if processor is None:
image_repl = self.get_image_repl( processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
) )
image_repl_tokens = encode_tokens( return num_patches * processor.image_seq_len
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height=image_processor.size["longest_edge"], height=image_processor.size["longest_edge"],
) )
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
): ):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token, _, _ = self.info._get_image_token(processor)
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge'] longest_edge = image_processor.max_image_size['longest_edge']
image_token = hf_processor.image_token.content
mm_data = { return {
"image": "image":
self._get_dummy_images(width=longest_edge, self._get_dummy_images(width=longest_edge,
height=longest_edge, height=longest_edge,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Idefics3MultiModalProcessor( class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]): BaseMultiModalProcessor[Idefics3ProcessingInfo]):
...@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor( ...@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor(
] ]
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [ num_patches = [
self.info.get_num_patches( self.info.get_num_patches(
image_width=size.width, image_width=size.width,
...@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor( ...@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor(
"image", num_patches), "image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor( ...@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content image_token, _, _ = self.info._get_image_token(hf_processor)
def get_replacement_idefics3(item_idx: int) -> str: def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.info.get_image_repl( image_repl = self.info.get_image_repl(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
) )
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
...@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None: if pixel_values is None and image_embeds is None:
return None return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
...@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs( return Idefics3ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
) )
if pixel_values is not None: if pixel_values is not None:
...@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask, pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches, num_patches=num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
] ]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_id, self.config.image_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol): ...@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill # Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated # TODO(ywang96): Remove this overload once v0 is deprecated
@overload @overload
......
...@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, ...@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC): ...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]), torch.tensor([len(item) for item in pixel_values_lst]),
} }
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0] num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches) image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text] text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor): ...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseInternVLProcessingInfo(BaseProcessingInfo): class BaseInternVLProcessingInfo(BaseProcessingInfo):
...@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
image_height=image_height, image_height=image_height,
) )
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo) ...@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return "<image>" * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
...@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
...@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. " raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}") f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=self._validate_pixel_values( pixel_values_flat=self._validate_pixel_values(
pixel_values_flat), pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_context_token_id, self.img_context_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -294,7 +294,7 @@ class LlamaModel(nn.Module): ...@@ -294,7 +294,7 @@ class LlamaModel(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -475,7 +475,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -475,7 +475,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -523,7 +523,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -523,7 +523,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def _init_model(self, def _init_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config, return LlamaModel(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
layer_type=layer_type) layer_type=layer_type)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
is_pp_missing_parameter) is_pp_missing_parameter)
...@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module): ...@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to( router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype) hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32)) return (router_scores, router_indices.to(torch.int32))
...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module): ...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm( self.qk_norm = RMSNorm(
hidden_size=self.q_size, hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
has_weight=False, has_weight=False,
dtype=torch.float32, dtype=torch.float32,
...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module): ...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
if self.rotary_emb is not None: if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None: if self.qk_norm is not None:
q = self.q_norm(q.float()).to(q.dtype) q = q.reshape(-1, self.num_heads, self.head_dim)
if self.k_norm is not None: q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
k = self.k_norm(k.float()).to(k.dtype) k = k.reshape(-1, self.num_kv_heads, self.head_dim)
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function # to NoPE layers, where the inference-time temperature tuning function
...@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module): ...@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module):
return output return output
class Llama4DecoderLayer(LlamaDecoderLayer): class Llama4DecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer): ...@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = config.rope_theta rope_theta = config.rope_theta
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
...@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel): ...@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
...@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM):
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config # update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config() gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config) gen_config.update(vllm_config.model_config.override_generation_config)
# enable temperature tuning by default when max_model_len > 32K
default_attn_temperature_tuning = \
vllm_config.model_config.max_model_len > 32768
vllm_config.model_config.hf_config.attn_temperature_tuning \ vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False) = gen_config.get(
LlamaForCausalLM.__init__(self, "attn_temperature_tuning", default_attn_temperature_tuning)
vllm_config=vllm_config,
prefix=prefix, super().__init__(vllm_config=vllm_config,
layer_type=Llama4DecoderLayer) prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self, def _init_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config, return Llama4Model(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
layer_type=layer_type) layer_type=layer_type)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple
import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from .utils import AutoWeightsLoader, maybe_prefix
# Module-level logger, named after this module.
logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose input layernorm can be switched off.

    This class subclasses the ``LlamaDecoderLayer`` imported from
    ``vllm.model_executor.models.llama`` and rebinds the same name, so any
    later reference to ``LlamaDecoderLayer`` in this module resolves to this
    subclass.
    """

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        """Build the base layer, optionally replacing its input layernorm.

        Args:
            config: Llama configuration forwarded to the base layer.
            disable_input_layernorm: When true, ``input_layernorm`` is
                replaced by ``nn.Identity()``.
            prefix: Module-name prefix forwarded to the base layer.
        """
        super().__init__(config, prefix=prefix)

        if not disable_input_layernorm:
            return

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        del self.input_layernorm
        self.input_layernorm = nn.Identity()
class LlamaModel(nn.Module):
    """Llama decoder stack that fuses token embeddings with hidden states.

    The forward pass embeds ``input_ids``, concatenates the embeddings with
    the incoming ``hidden_states`` along the last dimension, projects the
    result through a bias-free linear layer (``2 * hidden_size`` ->
    ``hidden_size``) and then runs the decoder layers.
    """

    def __init__(
        self,
        *,
        model_config: ModelConfig,
        start_layer_id: int = 0,
        prefix: str = "",
    ) -> None:
        """Create the embedding table, decoder layers and fusion projection.

        Args:
            model_config: Model configuration; its ``hf_config`` supplies
                the vocabulary size, hidden size and number of layers.
            start_layer_id: Offset added to each layer index when building
                the layer's prefix (``layers.{i + start_layer_id}``).
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()
        hf_config = model_config.hf_config
        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Only the first layer (index 0) is built with its input layernorm
        # disabled.
        decoder_layers = []
        for layer_pos in range(hf_config.num_hidden_layers):
            layer_prefix = maybe_prefix(
                prefix, f"layers.{layer_pos + start_layer_id}")
            decoder_layers.append(
                LlamaDecoderLayer(
                    hf_config,
                    layer_pos == 0,
                    prefix=layer_prefix,
                ))
        self.layers = nn.ModuleList(decoder_layers)

        self.fc = torch.nn.Linear(hf_config.hidden_size * 2,
                                  hf_config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the fused embedding/hidden-state input through all layers.

        Args:
            input_ids: Token ids passed to ``embed_tokens``.
            positions: Positions forwarded unchanged to every layer.
            hidden_states: Hidden states concatenated with the token
                embeddings along the last dimension.

        Returns:
            The sum of the last layer's ``hidden_states`` and ``residual``.
        """
        token_embeds = self.embed_tokens(input_ids)
        fused_input = torch.cat((token_embeds, hidden_states), dim=-1)
        hidden_states = self.fc(fused_input)

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        return hidden_states + residual

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names containing ``.q_proj``/``.k_proj``/``.v_proj`` are redirected
        to the fused ``.qkv_proj`` parameter, and ``.gate_proj``/``.up_proj``
        to ``.gate_up_proj``, each with the matching shard id. Every other
        name is loaded as-is.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.

        Raises:
            KeyError: If a (possibly renamed) name is not one of this
                module's named parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # The first mapping whose shard name occurs in the weight name
            # wins.
            stacked_match = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name),
                None,
            )
            if stacked_match is None:
                param = params_dict[name]
                plain_loader = getattr(param, "weight_loader",
                                       default_weight_loader)
                plain_loader(param, loaded_weight)
            else:
                fused_name, shard_name, shard_id = stacked_match
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                shard_loader = param.weight_loader
                shard_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        return loaded_params
class EagleLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper around the ``LlamaModel`` defined in this module."""

    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
        """Build the inner model and the logits processor.

        ``nn.Module.__init__`` is called directly, so the
        ``LlamaForCausalLM`` constructor is not run.

        Args:
            model_config: Model configuration; its ``hf_config`` is stored
                as ``self.config``.
            start_layer_id: Layer-index offset forwarded to ``LlamaModel``.
        """
        nn.Module.__init__(self)
        self.config = model_config.hf_config
        self.model = LlamaModel(
            model_config=model_config,
            start_layer_id=start_layer_id,
            prefix="model",
        )

        # A missing ``logit_scale`` in the config falls back to 1.0.
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            scale=getattr(self.config, "logit_scale", 1.0),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the inner model and return its output unchanged."""
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Every name that does not contain ``"lm_head"`` is prefixed with
        ``"model."`` before loading. When the config's
        ``tie_word_embeddings`` is truthy, weights under the ``"lm_head."``
        prefix are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skipped,
        )

        # Collected into a dict first, so a repeated name keeps only its
        # last tensor.
        model_weights = {
            (name if "lm_head" in name else "model." + name): loaded_weight
            for name, loaded_weight in weights
        }
        loader.load_weights(model_weights.items())
...@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel ...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features, from .vision import get_vision_encoder_info
select_patch_features)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class LlavaImageEmbeddingInputs(TypedDict): class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy( def _apply_feature_select_strategy(
self, self,
strategy: str, strategy: str,
...@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) ...@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class LlavaProcessingInfo(BaseLlavaProcessingInfo): class LlavaProcessingInfo(BaseLlavaProcessingInfo):
...@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor( ...@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes) for p, (h, w) in zip(pixel_values, image_sizes)
] ]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor( ...@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor( ...@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes) image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1} return {"video": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info() vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
...@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
class LlavaNextVideoDummyInputsBuilder( class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
processor = self.info.get_hf_processor() processor = self.info.get_hf_processor()
video_token = processor.video_token video_token = processor.video_token
return video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_videos = mm_counts.get("video", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { return {
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
...@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder( ...@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder(
) )
} }
return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)
class LlavaNextVideoMultiModalProcessor( class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
...@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds] return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
......
...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn ...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor # with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features( def _get_num_unpadded_features(
...@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
class LlavaOnevisionDummyInputsBuilder( class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder( ...@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder(
image_token = processor.image_token image_token = processor.image_token
video_token = processor.video_token video_token = processor.video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, self.info.get_num_frames_with_most_features(seq_len,
mm_counts) mm_counts)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
...@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder( ...@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder(
) )
} }
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class LlavaOnevisionMultiModalProcessor( class LlavaOnevisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
...@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim) image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
...@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group ...@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim, head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon, rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act, activation=config.hidden_act,
chunk_size=config.chunk_size,
quant_config=quant_config) quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor], mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params, hidden_states = self.mixer(hidden_states, mamba_cache_params,
sequence_idx) mamba2_metadata)
return hidden_states, residual return hidden_states, residual
...@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module): ...@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) mamba2_metadata = prepare_mamba2_metadata(
for i, (srt, end) in enumerate( chunk_size=self.config.chunk_size,
zip( input_ids=input_ids,
attn_metadata.query_start_loc, attn_metadata=attn_metadata,
attn_metadata.query_start_loc[1:], )
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
...@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module): ...@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer), i - self.start_layer),
sequence_idx=seq_idx) mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import ( ...@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData, DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems, ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import ProcessorInputs PromptUpdateDetails)
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
MiniCPMVDummyInputsBuilder, MiniCPMVDummyInputsBuilder,
...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, ...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix) maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict): ...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list. Length of each slice may vary, so pass it as a list.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs] MiniCPMOAudioEmbeddingInputs]
...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
...@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): ...@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data( def _parse_audio_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems( return MiniCPMOAudioEmbeddingItems(
data, data,
...@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {**super().get_supported_mm_limits(), "audio": None} return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio":
self.get_max_audio_tokens(),
}
def get_audio_placeholder( def get_audio_placeholder(
self, self,
audio_lens: int, audio_lens: int,
...@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step() pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100 fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 return (cnn_feat_in_chunk - pool_step) // pool_step + 1
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
def get_max_audio_chunks_with_most_features(self) -> int: def get_max_audio_chunks_with_most_features(self) -> int:
return 30 return 30
...@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio> num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features( def get_num_frames_with_most_features(
...@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
class MiniCPMODummyInputsBuilder( class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self, seq_len: int, mm_counts: Mapping[str, num_audios = mm_counts.get("audio", 0)
int]) -> ProcessorInputs:
audio_prompt_texts = self.info.audio_pattern * num_audios
return super().get_dummy_text(mm_counts) + audio_prompt_texts
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_len = self.info.get_max_audio_chunks_with_most_features() * \ audio_len = self.info.get_max_audio_chunks_with_most_features() * \
self.info.get_default_audio_sampling_rate() self.info.get_default_audio_sampling_rate()
processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts)
audio_prompt_texts = self.info.audio_pattern * num_audios
audio_mm_data = { audio_mm_data = {
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len, num_audios=num_audios)
} }
return ProcessorInputs( return {
prompt_text=processor_inputs.prompt_text + audio_prompt_texts, **super().get_dummy_mm_data(seq_len, mm_counts),
mm_data={ **audio_mm_data,
**processor_inputs.mm_data, }
**audio_mm_data,
},
)
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
...@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor( ...@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {} audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else: else:
audio_inputs = self._base_call_hf_processor( audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios), prompts=[self.info.audio_pattern] * len(parsed_audios),
...@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor( ...@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor(
] ]
audio_inputs["audio_features"] = unpadded_audio_features audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
...@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor(
else: else:
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
return self.get_audio_prompt_texts(audio_len) return PromptUpdateDetails.select_text(
self.get_audio_prompt_texts(audio_len),
"<unk>",
)
return [ return [
*base_updates, *base_updates,
...@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor) assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item()) self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. " raise ValueError("Incorrect type of audio_embeds. "
...@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=audio_embeds_flat, audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
) )
if not isinstance(audio_features, (torch.Tensor, list)): if not isinstance(audio_features, (torch.Tensor, list)):
...@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features", type="audio_features",
audio_features=audio_features_flat, audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat, audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
...@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(audio_features)
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
return multimodal_embeddings return multimodal_embeddings
...@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize, ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems, ModalityData, ModalityDataItems,
...@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ...@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems) VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists from vllm.utils import flatten_2d_lists
...@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 16 _MAX_FRAMES_PER_VIDEO = 16
...@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format. This should be in `(height, width)` format.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices: torch.Tensor num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
...@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): ...@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor. instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs] MiniCPMVImageEmbeddingInputs]
...@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"), video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos), video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
...@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data( def _parse_image_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems( return MiniCPMVImageEmbeddingItems(
data, data,
...@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_video_data( def _parse_video_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems( return MiniCPMVVideoEmbeddingItems(
data, data,
...@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return mm_limits return mm_limits
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(
seq_len, mm_counts)
return mm_max_tokens
def get_slice_image_placeholder( def get_slice_image_placeholder(
self, self,
image_size: ImageSize, image_size: ImageSize,
...@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id=use_image_id, use_image_id=use_image_id,
) )
def get_sliced_grid(
self,
image_size: ImageSize,
# For MiniCPM V/O 2.6
max_slice_nums: Optional[int] = None,
) -> Optional[tuple[int, int]]:
image_processor = self.get_image_processor()
version = self.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_sliced_grid(image_size)
if max_slice_nums is None:
max_slice_nums = image_processor.max_slice_nums
return image_processor.get_sliced_grid(
image_size,
max_slice_nums=max_slice_nums,
)
def get_num_image_tokens( def get_num_image_tokens(
self, self,
image_size: ImageSize, image_size: ImageSize,
max_slice_nums: Optional[int] = None, max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int: ) -> int:
tokenizer = self.get_tokenizer() image_processor = self.get_image_processor()
image_placeholders = self.get_slice_image_placeholder(
grid = self.get_sliced_grid(
image_size, image_size,
max_slice_nums=max_slice_nums, max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
) )
image_token_ids = tokenizer.encode(image_placeholders, if grid is None:
add_special_tokens=False) ncols = nrows = 0
else:
ncols, nrows = grid
return len(image_token_ids) return (ncols * nrows + 1) * image_processor.image_feature_size
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features() image_size = self.get_image_size_with_most_features()
...@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens( return self.get_num_image_tokens(
frame_size, frame_size,
max_slice_nums=self.get_video_max_slice_num(), max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
) )
def get_max_video_tokens( def get_max_video_tokens(
...@@ -482,11 +472,20 @@ _I = TypeVar("_I", ...@@ -482,11 +472,20 @@ _I = TypeVar("_I",
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return image_prompt_texts + video_prompt_texts
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
num_video_frames = \ num_video_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=image_width, self._get_dummy_images(width=image_width,
height=image_height, height=image_height,
...@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
] * num_videos, ] * num_videos,
} }
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return ProcessorInputs(prompt_text=image_prompt_texts +
video_prompt_texts,
mm_data=mm_data)
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id=False, use_image_id=False,
) * num_frames ) * num_frames
def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
return torch.tensor(input_ids) == unk_token_id
def process_images( def process_images(
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
...@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id) image_inputs["image_token_id"] = torch.tensor(unk_token_id)
...@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
frame_sizes = [
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
]
num_frames = [
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
]
video_repl_features = [
self.get_video_prompt_texts(size, nframes)
for size, nframes in zip(frame_sizes, num_frames)
]
tokenizer = self.info.get_tokenizer()
video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id) video_inputs["video_token_id"] = torch.tensor(unk_token_id)
...@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.get_image_prompt_texts(image_size, item_idx) return PromptUpdateDetails.select_text(
self.get_image_prompt_texts(image_size, item_idx),
"<unk>",
)
def get_video_replacement(item_idx: int): def get_video_replacement(item_idx: int):
videos = mm_items.get_items( videos = mm_items.get_items(
...@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size = videos.get_frame_size(item_idx) frame_size = videos.get_frame_size(item_idx)
num_frames = videos.get_num_frames(item_idx) num_frames = videos.get_num_frames(item_idx)
return self.get_video_prompt_texts(frame_size, num_frames) return PromptUpdateDetails.select_text(
self.get_video_prompt_texts(frame_size, num_frames),
"<unk>",
)
get_replacement = { get_replacement = {
"image": get_image_replacement, "image": get_image_replacement,
...@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert isinstance(image_token_id, torch.Tensor) assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item()) self.mm_token_ids.add(image_token_id.flatten().unique().item())
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError( raise ValueError(
...@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return MiniCPMVImageEmbeddingInputs( return MiniCPMVImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=image_embeds_flat, image_embeds=image_embeds_flat,
embed_is_patch=embed_is_patch,
) )
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
...@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type="pixel_values", type="pixel_values",
pixel_values=pixel_values_flat, pixel_values=pixel_values_flat,
tgt_sizes=tgt_sizes_flat, tgt_sizes=tgt_sizes_flat,
embed_is_patch=embed_is_patch,
num_slices=num_slices_flat, num_slices=num_slices_flat,
) )
...@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
image_features = self._process_vision_input(image_input) image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(image_features)
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_features = self._process_vision_input(video_input) video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(video_features)
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
return multimodal_embeddings return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
...@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
list(self.mm_token_ids), list(self.mm_token_ids),
) )
return inputs_embeds return inputs_embeds
......
...@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features, from .vision import get_vision_encoder_info
select_patch_features)
class Mistral3ImagePixelInputs(TypedDict): class Mistral3ImagePixelInputs(TypedDict):
...@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict): ...@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
class Mistral3PatchMerger(nn.Module): class Mistral3PatchMerger(nn.Module):
""" """
...@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height) return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
...@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor( ...@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor(
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
] ]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor( ...@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor( ...@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -418,8 +379,6 @@ def init_vision_tower_for_llava( ...@@ -418,8 +379,6 @@ def init_vision_tower_for_llava(
) )
# TODO(mgoin): Support V1, there are issues with image batching/chunking
# that need to be resolved first.
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
_build_mistral3_processor, _build_mistral3_processor,
info=_build_mistral3_info, info=_build_mistral3_info,
...@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
assert self.config.vision_config.model_type == "pixtral"
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
return Mistral3ImagePixelInputs( return Mistral3ImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=flatten_bn(embed_is_patch),
) )
def _process_image_input( def _process_image_input(
...@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, ) image_embeds = (image_embeds, )
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
...@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return scatter_patch_features( return vision_embeddings
vision_embeddings,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -260,6 +260,8 @@ class MixtralModel(nn.Module): ...@@ -260,6 +260,8 @@ class MixtralModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -313,88 +315,6 @@ class MixtralModel(nn.Module): ...@@ -313,88 +315,6 @@ class MixtralModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -415,9 +335,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -415,9 +335,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales # Loading kv cache quantization scales
...@@ -489,3 +406,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -489,3 +406,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
return loader.load_weights(weights)
...@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
...@@ -52,16 +52,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -52,16 +52,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalKwargs) MultiModalFieldConfig, MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from .clip import CLIPMLP from .clip import CLIPMLP
from .interfaces import SupportsMultiModal, SupportsV0Only from .interfaces import SupportsMultiModal, SupportsV0Only
...@@ -106,16 +107,6 @@ class MllamaProcessingInfo(BaseProcessingInfo): ...@@ -106,16 +107,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size = self.get_hf_config().vision_config.image_size image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size) return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int, def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int: image_width: int) -> int:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
...@@ -141,31 +132,31 @@ class MllamaProcessingInfo(BaseProcessingInfo): ...@@ -141,31 +132,31 @@ class MllamaProcessingInfo(BaseProcessingInfo):
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]): class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
): ):
...@@ -211,6 +202,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -211,6 +202,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
# } # }
if mm_data: if mm_data:
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
# Since only the last group of consecutive images # Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to # are attended by the decoded tokens, we only need to
# get the number of tokens for those images. # get the number of tokens for those images.
...@@ -227,7 +221,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -227,7 +221,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
num_tokens = decode_tiles * token_per_chunk num_tokens = decode_tiles * token_per_chunk
mm_inputs["encoder_prompt_token_ids"] = [image_token_id mm_inputs["encoder_prompt_token_ids"] = [image_token_id
] * num_tokens ] * num_tokens
mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens mm_inputs["encoder_prompt"] = image_token * num_tokens
return mm_inputs return mm_inputs
...@@ -1188,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1188,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__() super().__init__()
config: MllamaConfig = vllm_config.model_config.hf_config config: MllamaConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
...@@ -1306,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1306,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _get_and_validate_encoder_lens(
self,
encoder_seq_lens: List[int],
num_tiles: List[List[int]],
num_tokens_per_tile: int,
) -> List[int]:
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
# remove 0 encoder len entries for text-only requests for these
# assertions
attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
for actual_len, last_group_len in zip(actual_encoder_seq_lens,
attn_metadata_lens):
assert actual_len >= last_group_len
return actual_encoder_seq_lens
def flat_encoder_result(self, cross_attention_states: torch.Tensor, def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]): actual_encoder_seq_lens: List[int]):
...@@ -1325,6 +1345,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1325,6 +1345,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states = cross_attention_states_flat cross_attention_states = cross_attention_states_flat
return cross_attention_states return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states( def get_cross_attention_states(
self, self,
image_inputs: MllamaImagePixelInputs, image_inputs: MllamaImagePixelInputs,
...@@ -1430,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1430,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
else: else:
skip_cross_attention = False skip_cross_attention = False
# Get the actual number of encoder tokens for each sample. num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tokens_per_tile = calc_token_per_chunk(self.image_size) num_tokens_per_tile = calc_token_per_chunk(self.image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
] attn_metadata.encoder_seq_lens,
for actual_len, last_group_len in zip( num_tiles,
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens): num_tokens_per_tile,
assert actual_len >= last_group_len )
cross_attention_states = self.get_cross_attention_states( cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens) image_inputs, attn_metadata, actual_encoder_seq_lens)
...@@ -1521,6 +1538,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1521,6 +1538,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
updated_params.add(name) updated_params.add(name)
return updated_params return updated_params
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_model")
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
for mask in sparse_mask: for mask in sparse_mask:
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
# limitations under the License. # limitations under the License.
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from functools import cached_property
from itertools import tee from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
...@@ -24,7 +25,6 @@ import torch ...@@ -24,7 +25,6 @@ import torch
from torch import nn from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict from transformers.image_utils import SizeDict
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.llama4 import Llama4Processor from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import ( from transformers.models.llama4.image_processing_llama4_fast import (
find_supported_resolutions, get_best_fit) find_supported_resolutions, get_best_fit)
...@@ -33,33 +33,30 @@ from vllm.attention.layer import MultiHeadAttention ...@@ -33,33 +33,30 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import MultiModalEmbeddings, SupportsMultiModal from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .llama4 import Llama4ForCausalLM
maybe_prefix, merge_multimodal_embeddings) from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
from .vision import scatter_patch_features, select_patch_features merge_multimodal_embeddings)
logger = init_logger(__name__)
class Llama4ImagePatchInputs(TypedDict): class Llama4ImagePatchInputs(TypedDict):
...@@ -76,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict): ...@@ -76,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]] aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
""" """
A list of aspect ratios corresponding to the number of tiles A list of aspect ratios corresponding to the number of tiles
...@@ -345,7 +338,7 @@ class Llama4VisionEncoder(nn.Module): ...@@ -345,7 +338,7 @@ class Llama4VisionEncoder(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> BaseModelOutput: ) -> torch.Tensor:
r""" r"""
Args: Args:
inputs_embeds (`torch.FloatTensor` of shape inputs_embeds (`torch.FloatTensor` of shape
...@@ -361,7 +354,7 @@ class Llama4VisionEncoder(nn.Module): ...@@ -361,7 +354,7 @@ class Llama4VisionEncoder(nn.Module):
layer_outputs = encoder_layer(hidden_states) layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
return BaseModelOutput(last_hidden_state=hidden_states, ) return hidden_states
class Llama4UnfoldConvolution(nn.Module): class Llama4UnfoldConvolution(nn.Module):
...@@ -433,7 +426,7 @@ class Llama4VisionModel(nn.Module): ...@@ -433,7 +426,7 @@ class Llama4VisionModel(nn.Module):
def forward( def forward(
self, self,
images_flattened: torch.Tensor, images_flattened: torch.Tensor,
) -> BaseModelOutput: ) -> torch.Tensor:
# Patch embedding # Patch embedding
hidden_state = self.patch_embedding(images_flattened) hidden_state = self.patch_embedding(images_flattened)
num_tiles, num_patches, hidden_dim = hidden_state.shape num_tiles, num_patches, hidden_dim = hidden_state.shape
...@@ -458,8 +451,7 @@ class Llama4VisionModel(nn.Module): ...@@ -458,8 +451,7 @@ class Llama4VisionModel(nn.Module):
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
# Apply encoder # Apply encoder
output = self.model(hidden_state) hidden_state = self.model(hidden_state)
hidden_state = output.last_hidden_state
hidden_state = self.layernorm_post(hidden_state) hidden_state = self.layernorm_post(hidden_state)
# Remove CLS token output # Remove CLS token output
...@@ -468,10 +460,7 @@ class Llama4VisionModel(nn.Module): ...@@ -468,10 +460,7 @@ class Llama4VisionModel(nn.Module):
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state = self.vision_adapter(hidden_state) hidden_state = self.vision_adapter(hidden_state)
return BaseModelOutput( return hidden_state
last_hidden_state=hidden_state,
attentions=None,
)
class Mllama4ProcessingInfo(BaseProcessingInfo): class Mllama4ProcessingInfo(BaseProcessingInfo):
...@@ -488,7 +477,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -488,7 +477,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
**kwargs) **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 10} # Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
return {"image": None}
@staticmethod @staticmethod
def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int: def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
...@@ -507,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -507,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
return {"image": mm_max_tokens}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
...@@ -581,33 +561,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -581,33 +561,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for (r_h, r_w) in aspect_ratios for (r_h, r_w) in aspect_ratios
] ]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
num_patches_per_chunk = self.info.get_patch_per_chunk(
vision_config)
expanded_image_tokens_list = [
processor._prompt_split_image(aspect_ratio,
num_patches_per_chunk)
for aspect_ratio in aspect_ratios
]
expanded_image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in expanded_image_tokens_list
]
embed_is_patch = [
torch.tensor(tokens) == patch_id
for tokens in expanded_image_token_ids
]
processed_outputs["aspect_ratios"] = aspect_ratios processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["patches_per_image"] = torch.tensor( processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image) patches_per_image)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
...@@ -622,7 +578,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -622,7 +578,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
"image", patches_per_image), "image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"), patches_per_image=MultiModalFieldConfig.batched("image"),
aspect_ratios=MultiModalFieldConfig.batched("image"), aspect_ratios=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -642,12 +597,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -642,12 +597,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config) num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
img_patch_token = hf_processor.img_patch_token
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx] aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
return hf_processor._prompt_split_image(
repl = hf_processor._prompt_split_image(
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
num_patches_per_chunk=num_patches_per_chunk) num_patches_per_chunk=num_patches_per_chunk,
)
return PromptUpdateDetails.select_text(repl, img_patch_token)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -660,36 +620,39 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -660,36 +620,39 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.fake_image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
(target_width, (target_width,
target_height) = self.info.get_image_size_with_most_features() target_height) = self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
image_token = self.info.get_hf_processor().fake_image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Mllama4MultiModalProcessor, Mllama4MultiModalProcessor,
info=Mllama4ProcessingInfo, info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder, dummy_inputs=Mllama4DummyInputsBuilder,
) )
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
} }
...@@ -710,13 +673,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -710,13 +673,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self.config, self.config,
None, None,
prefix=maybe_prefix(prefix, "multi_modal_projector")) prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
architectures=["Llama4ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"))
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) self.language_model = _initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
...@@ -730,11 +702,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -730,11 +702,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
flat_pixel_values = flatten_bn(pixel_values, concat=True) flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image")) patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
aspect_ratios = kwargs.pop("aspect_ratios", None) aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)): if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. " raise ValueError("Incorrect type of aspect_ratios. "
...@@ -744,7 +711,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -744,7 +711,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
type="pixel_values", type="pixel_values",
flat_data=flat_pixel_values, flat_data=flat_pixel_values,
patches_per_image=patches_per_image, patches_per_image=patches_per_image,
embed_is_patch=embed_is_patch,
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
) )
...@@ -752,8 +718,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -752,8 +718,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
flat_data = image_input["flat_data"] flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist() patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data).last_hidden_state
return vision_embeddings_flat.split(patches_per_image, dim=0) vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat)
return [
img.flatten(0, 1)
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self, def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]: **kwargs) -> Optional[MultiModalEmbeddings]:
...@@ -761,20 +737,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -761,20 +737,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
if image_input is None: if image_input is None:
return None return None
# num_images x [num_chunks, num_patches, hidden_dim] return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat = [img.flatten(0, 1) for img in image_features]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat = [
is_patch.flatten(0, 1)
for is_patch in image_input["embed_is_patch"]
]
return scatter_patch_features(
image_features_flat,
embed_is_patch_flat,
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -784,11 +747,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -784,11 +747,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
multimodal_embeddings = torch.cat(multimodal_embeddings)
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, select_patch_features(mm_embeddings), input_ids,
self.config.image_token_index) inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds return inputs_embeds
...@@ -800,9 +764,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -800,9 +764,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
# NOTE: In v1, inputs_embeds is always generated at model runner, this if intermediate_tensors is not None:
# condition is for v0 compatibility. inputs_embeds = None
if "pixel_values" in kwargs:
# NOTE: In v1, inputs_embeds is always generated at model runner,
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs) vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings) vision_embeddings)
...@@ -857,9 +824,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -857,9 +824,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
# language_model is an Llama4ForCausalLM instance. We load it's # language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine. # using llama4's load_weights routine.
language_model_prefix = "language_model.model."
language_model_weights, other_weights = self.separate_weights( language_model_weights, other_weights = self.separate_weights(
weights, prefix=language_model_prefix) weights, prefix="language_model.model.")
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights( loaded_language_model_params = loader.load_weights(
language_model_weights) language_model_weights)
......
...@@ -41,13 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,13 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...@@ -56,7 +58,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -56,7 +58,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
...@@ -84,14 +85,6 @@ class MolmoImageInputs(TypedDict): ...@@ -84,14 +85,6 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)` Shape: `(batch_size * num_images, num_crops, num_patch)`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: torch.Tensor num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
...@@ -1146,30 +1139,6 @@ class MolmoProcessorWrapper: ...@@ -1146,30 +1139,6 @@ class MolmoProcessorWrapper:
if image_input_idx is not None: if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0 feat_is_patch = image_input_idx >= 0
input_is_embed = torch.isin(
input_ids,
torch.tensor([
self.image_patch_id,
self.im_col_id,
self.im_start_id,
self.im_end_id,
]),
)
embed_ids = input_ids[input_is_embed]
embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
assert len(embed_start) == len(embed_end) == len(images)
embed_is_patch = [
embed_is_patch[start:end + 1]
for start, end in zip(embed_start, embed_end)
]
tilings = [ tilings = [
self.select_tiling( self.select_tiling(
image_width=image.size[0], image_width=image.size[0],
...@@ -1181,7 +1150,6 @@ class MolmoProcessorWrapper: ...@@ -1181,7 +1150,6 @@ class MolmoProcessorWrapper:
assert num_crops.sum() == len(feat_is_patch) assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch outputs["feat_is_patch"] = feat_is_patch
outputs["embed_is_patch"] = embed_is_patch
outputs["num_crops"] = num_crops outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id outputs["img_patch_id"] = self.image_patch_id
...@@ -1197,13 +1165,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1197,13 +1165,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -1220,26 +1181,13 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1220,26 +1181,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
) )
pooling_size = processor.pooling_size pooling_size = processor.pooling_size
base_image_input_size = processor.base_image_input_size image_token_length_w = processor.image_token_length_w
base_image_input_d = processor.image_patch_size image_token_length_h = processor.image_token_length_h
crop_patches = base_image_input_size[0] // base_image_input_d
per_row = ncols // pooling_size + 1
joint = per_row * (nrows // pooling_size) + 2
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
resize = (image_token_length + 1) * image_token_length + 2
return resize + joint extra = image_token_length_w * image_token_length_h
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
def get_max_image_tokens(self) -> int: return extra + joint
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -1269,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1269,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo):
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...@@ -1328,7 +1274,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1328,7 +1274,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops), "image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes( feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops), "image", num_crops),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images), img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -1368,8 +1313,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1368,8 +1313,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint = ([img_start_id] + joint_row * joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id]) ((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint return PromptUpdateDetails.select_token_id(
return image_tokens extra_joint + joint,
embed_token_id=img_patch_id,
)
return [ return [
PromptInsertion( PromptInsertion(
...@@ -1475,11 +1422,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1475,11 +1422,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise ValueError("Incorrect type of feat_is_patch. " raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}") f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
if not isinstance(num_crops, (torch.Tensor, list)): if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_crops. "
...@@ -1491,14 +1433,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1491,14 +1433,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}") f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
image_masks=image_masks, image_masks=image_masks,
feat_is_patch=feat_is_patch, feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops, num_crops=num_crops,
) )
...@@ -1531,18 +1471,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1531,18 +1471,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
) )
] ]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -1556,7 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1556,7 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_patch_id, self.img_patch_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -15,12 +15,11 @@ from transformers import PretrainedConfig ...@@ -15,12 +15,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor, from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
...@@ -57,7 +56,7 @@ class NVLMProcessor(BaseInternVLProcessor): ...@@ -57,7 +56,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find "<tile" as a subsequence of "<Image><tile" # when trying to find "<tile" as a subsequence of "<Image><tile"
repl = "<Image>" + features + "</Image>" repl = "<Image>" + features + "</Image>"
return PromptUpdateDetails(full=repl, features=repl) return PromptUpdateDetails.select_text(repl, IMG_PAD)
class NVLMProcessingInfo(BaseInternVLProcessingInfo): class NVLMProcessingInfo(BaseInternVLProcessingInfo):
...@@ -84,57 +83,32 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo): ...@@ -84,57 +83,32 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**kwargs, **kwargs,
) )
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
tokenizer = hf_processor.tokenizer
max_num_patches = hf_processor.max_dynamic_patch class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers = [
f"<tile_{i+1}>" for i in range(max_num_patches)
]
if hf_processor.use_thumbnail and max_num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
# "<Image><tile" is tokenized as ["<Image", "><", "tile"] def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
# so we include <tile_1> in the start_str num_images = mm_counts.get("image", 0)
start_str = "<Image>" + tile_pos_identifiers.pop(0)
end_str = "</Image>"
start_token_len = len(tokenizer.encode(start_str))
end_token_len = len(tokenizer.encode(end_str))
tile_token_len = sum(
len(tokenizer.encode(identifier))
for identifier in tile_pos_identifiers)
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
return super().get_max_image_tokens() + non_image_tokens_num
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
return "<image>\n" * num_images
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]): def get_dummy_mm_data(
def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
prompt_text="<image>\n" * num_images,
mm_data=mm_data,
)
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
...@@ -177,10 +151,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]): ...@@ -177,10 +151,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl = hf_processor.get_image_repl(feature_size, num_patches) repl = hf_processor.get_image_repl(feature_size, num_patches)
return PromptUpdateDetails( return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
full=repl.full + "\n",
features=repl.features + "\n",
)
# See note in dummy data regarding why we have the extra newline # See note in dummy data regarding why we have the extra newline
return [ return [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment