"examples/backends/trtllm/deploy/agg_router.yaml" did not exist on "ccc8c6270e2edea3c1051b07d10b663b5ba5d761"
Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
......@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
encode_tokens)
MultiModalDataItems, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
# yapf: disable
......@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict):
......@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
......@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
......@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resize_output_size(self,
*,
height: int,
......@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return grid_w * grid_h + 1
def _get_image_token(
self,
processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_image_token = processor.global_image_tag
return image_token, fake_image_token, global_image_token
def get_image_repl(
self,
*,
......@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_img_token = processor.global_image_tag
image_token, fake_image_token, global_img_token = self._get_image_token(
processor)
image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>"
......@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
if processor is None:
processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
return num_patches * processor.image_seq_len
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
......@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height=image_processor.size["longest_edge"],
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token, _, _ = self.info._get_image_token(processor)
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge']
image_token = hf_processor.image_token.content
mm_data = {
return {
"image":
self._get_dummy_images(width=longest_edge,
height=longest_edge,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]):
......@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor(
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [
self.info.get_num_patches(
image_width=size.width,
......@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
image_token, _, _ = self.info._get_image_token(hf_processor)
def get_replacement_idefics3(item_idx: int) -> str:
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
image_repl = self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [
PromptReplacement(
modality="image",
......@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None:
return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
......@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
......@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
......@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
......@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds
......
......@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
"""
...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
......
......@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>'
IMG_END = '</img>'
......@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
......@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]),
}
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
......@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features)
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseInternVLProcessingInfo(BaseProcessingInfo):
......@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
......@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return "<image>" * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
......@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
......@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
......@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
......@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
......@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.img_context_token_id,
)
return inputs_embeds
......
......@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -294,7 +294,7 @@ class LlamaModel(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -475,7 +475,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -523,7 +523,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
......
......@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import torch
from torch import nn
......@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index,
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
is_pp_missing_parameter)
......@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
......@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm(
hidden_size=self.q_size,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
self.qk_norm = RMSNorm(
hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
......@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None:
q = self.q_norm(q.float()).to(q.dtype)
if self.k_norm is not None:
k = self.k_norm(k.float()).to(k.dtype)
if self.qk_norm is not None:
q = q.reshape(-1, self.num_heads, self.head_dim)
q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
......@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module):
return output
class Llama4DecoderLayer(LlamaDecoderLayer):
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
......@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
......@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
......@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM):
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config
# update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
# enable temperature tuning by default when max_model_len > 32K
default_attn_temperature_tuning = \
vllm_config.model_config.max_model_len > 32768
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False)
LlamaForCausalLM.__init__(self,
vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
= gen_config.get(
"attn_temperature_tuning", default_attn_temperature_tuning)
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer):
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple
import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
disable_input_layernorm: bool,
prefix: str = "",
) -> None:
super().__init__(config, prefix=prefix)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
if disable_input_layernorm:
del self.input_layernorm
self.input_layernorm = nn.Identity()
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
return hidden_states + residual
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
......@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
......@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
from .vision import get_vision_encoder_info
class LlavaImagePixelInputs(TypedDict):
......@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
......@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
self,
strategy: str,
......@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
......@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes)
]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _get_mm_fields_config(
......@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
......@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return tokens
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [
PromptReplacement(
......@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral":
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
)
return LlavaImagePixelInputs(
......@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
......@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
......
......@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
......@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_videos = mm_counts.get("video", 0)
processor = self.info.get_hf_processor()
video_token = processor.video_token
return video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_videos = mm_counts.get("video", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = {
return {
"video":
self._get_dummy_videos(
width=target_width,
......@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder(
)
}
return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)
class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
......@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs)
......
......@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
......@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(
......@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
......@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder(
image_token = processor.image_token
video_token = processor.video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len,
mm_counts)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
......@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder(
)
}
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class LlavaOnevisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
......@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
......@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import (
......@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
chunk_size=config.chunk_size,
quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor],
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
if residual is None:
......@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
sequence_idx)
mamba2_metadata)
return hidden_states, residual
......@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
for i in range(len(self.layers)):
layer = self.layers[i]
......@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
sequence_idx=seq_idx)
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......
......@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import (
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
MiniCPMVDummyInputsBuilder,
......@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu")
......@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
......@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs]
......@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
)
......@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems(
data,
......@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio":
self.get_max_audio_tokens(),
}
def get_audio_placeholder(
self,
audio_lens: int,
......@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
return (cnn_feat_in_chunk - pool_step) // pool_step + 1
def get_max_audio_chunks_with_most_features(self) -> int:
return 30
......@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio>
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features(
......@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs(
self, seq_len: int, mm_counts: Mapping[str,
int]) -> ProcessorInputs:
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
audio_prompt_texts = self.info.audio_pattern * num_audios
return super().get_dummy_text(mm_counts) + audio_prompt_texts
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
audio_len = self.info.get_max_audio_chunks_with_most_features() * \
self.info.get_default_audio_sampling_rate()
processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts)
audio_prompt_texts = self.info.audio_pattern * num_audios
audio_mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
mm_data={
**processor_inputs.mm_data,
**audio_mm_data,
},
)
return {
**super().get_dummy_mm_data(seq_len, mm_counts),
**audio_mm_data,
}
class MiniCPMOMultiModalProcessor(
......@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else:
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
......@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor(
]
audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
......@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor(
else:
audio_len = audios.get_audio_length(item_idx)
return self.get_audio_prompt_texts(audio_len)
return PromptUpdateDetails.select_text(
self.get_audio_prompt_texts(audio_len),
"<unk>",
)
return [
*base_updates,
......@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
......@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
)
if not isinstance(audio_features, (torch.Tensor, list)):
......@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features",
audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
......@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios":
audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
multimodal_embeddings += tuple(audio_features)
return multimodal_embeddings
......@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems,
......@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
......@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
......@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
......@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs]
......@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
......@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems(
data,
......@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_video_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems(
data,
......@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return mm_limits
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(
seq_len, mm_counts)
return mm_max_tokens
def get_slice_image_placeholder(
self,
image_size: ImageSize,
......@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id=use_image_id,
)
def get_sliced_grid(
self,
image_size: ImageSize,
# For MiniCPM V/O 2.6
max_slice_nums: Optional[int] = None,
) -> Optional[tuple[int, int]]:
image_processor = self.get_image_processor()
version = self.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_sliced_grid(image_size)
if max_slice_nums is None:
max_slice_nums = image_processor.max_slice_nums
return image_processor.get_sliced_grid(
image_size,
max_slice_nums=max_slice_nums,
)
def get_num_image_tokens(
self,
image_size: ImageSize,
max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int:
tokenizer = self.get_tokenizer()
image_placeholders = self.get_slice_image_placeholder(
image_processor = self.get_image_processor()
grid = self.get_sliced_grid(
image_size,
max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
)
image_token_ids = tokenizer.encode(image_placeholders,
add_special_tokens=False)
if grid is None:
ncols = nrows = 0
else:
ncols, nrows = grid
return len(image_token_ids)
return (ncols * nrows + 1) * image_processor.image_feature_size
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features()
......@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens(
frame_size,
max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
)
def get_max_video_tokens(
......@@ -482,11 +472,20 @@ _I = TypeVar("_I",
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return image_prompt_texts + video_prompt_texts
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
......@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
num_video_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = {
return {
"image":
self._get_dummy_images(width=image_width,
height=image_height,
......@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
] * num_videos,
}
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return ProcessorInputs(prompt_text=image_prompt_texts +
video_prompt_texts,
mm_data=mm_data)
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
......@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id=False,
) * num_frames
def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
return torch.tensor(input_ids) == unk_token_id
def process_images(
self,
mm_data: Mapping[str, object],
......@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
......@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
frame_sizes = [
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
]
num_frames = [
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
]
video_repl_features = [
self.get_video_prompt_texts(size, nframes)
for size, nframes in zip(frame_sizes, num_frames)
]
tokenizer = self.info.get_tokenizer()
video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
......@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size = images.get_image_size(item_idx)
return self.get_image_prompt_texts(image_size, item_idx)
return PromptUpdateDetails.select_text(
self.get_image_prompt_texts(image_size, item_idx),
"<unk>",
)
def get_video_replacement(item_idx: int):
videos = mm_items.get_items(
......@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size = videos.get_frame_size(item_idx)
num_frames = videos.get_num_frames(item_idx)
return self.get_video_prompt_texts(frame_size, num_frames)
return PromptUpdateDetails.select_text(
self.get_video_prompt_texts(frame_size, num_frames),
"<unk>",
)
get_replacement = {
"image": get_image_replacement,
......@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item())
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
......@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return MiniCPMVImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds_flat,
embed_is_patch=embed_is_patch,
)
if not isinstance(pixel_values, (torch.Tensor, list)):
......@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type="pixel_values",
pixel_values=pixel_values_flat,
tgt_sizes=tgt_sizes_flat,
embed_is_patch=embed_is_patch,
num_slices=num_slices_flat,
)
......@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if modality == "images":
image_input = modalities["images"]
image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple(
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
multimodal_embeddings += tuple(image_features)
if modality == "videos":
video_input = modalities["videos"]
video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple(
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
multimodal_embeddings += tuple(video_features)
return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
list(self.mm_token_ids),
)
return inputs_embeds
......
......@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
from .vision import get_vision_encoder_info
class Mistral3ImagePixelInputs(TypedDict):
......@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
class Mistral3PatchMerger(nn.Module):
"""
......@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
......@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor(
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _get_mm_fields_config(
......@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
......@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return tokens
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [
PromptReplacement(
......@@ -418,8 +379,6 @@ def init_vision_tower_for_llava(
)
# TODO(mgoin): Support V1, there are issues with image batching/chunking
# that need to be resolved first.
@MULTIMODAL_REGISTRY.register_processor(
_build_mistral3_processor,
info=_build_mistral3_info,
......@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
assert self.config.vision_config.model_type == "pixtral"
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
return Mistral3ImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=flatten_bn(embed_is_patch),
)
def _process_image_input(
......@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, )
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input)
return scatter_patch_features(
vision_embeddings,
image_input["embed_is_patch"],
)
return vision_embeddings
def get_input_embeddings(
self,
......@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
......
......@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -260,6 +260,8 @@ class MixtralModel(nn.Module):
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
......@@ -313,88 +315,6 @@ class MixtralModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -415,9 +335,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
......@@ -489,3 +406,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
return loader.load_weights(weights)
......@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
......@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
......
......@@ -52,16 +52,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalEncDecInputs,
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal, SupportsV0Only
......@@ -106,16 +107,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int:
vision_config = self.get_hf_config().vision_config
......@@ -141,31 +132,31 @@ class MllamaProcessingInfo(BaseProcessingInfo):
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
):
......@@ -211,6 +202,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
# }
if mm_data:
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# get the number of tokens for those images.
......@@ -227,7 +221,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
num_tokens = decode_tiles * token_per_chunk
mm_inputs["encoder_prompt_token_ids"] = [image_token_id
] * num_tokens
mm_inputs["encoder_prompt"] = "<|image|>" * num_tokens
mm_inputs["encoder_prompt"] = image_token * num_tokens
return mm_inputs
......@@ -1188,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
super().__init__()
config: MllamaConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
......@@ -1306,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _get_and_validate_encoder_lens(
self,
encoder_seq_lens: List[int],
num_tiles: List[List[int]],
num_tokens_per_tile: int,
) -> List[int]:
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
# remove 0 encoder len entries for text-only requests for these
# assertions
attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
for actual_len, last_group_len in zip(actual_encoder_seq_lens,
attn_metadata_lens):
assert actual_len >= last_group_len
return actual_encoder_seq_lens
def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
......@@ -1325,6 +1345,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states = cross_attention_states_flat
return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
......@@ -1430,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
else:
skip_cross_attention = False
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles_tensor = kwargs.pop("num_tiles")
num_tiles = [t.tolist() for t in num_tiles_tensor]
num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
num_tokens_per_tile = calc_token_per_chunk(self.image_size)
actual_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
for actual_len, last_group_len in zip(
actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
assert actual_len >= last_group_len
actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
attn_metadata.encoder_seq_lens,
num_tiles,
num_tokens_per_tile,
)
cross_attention_states = self.get_cross_attention_states(
image_inputs, attn_metadata, actual_encoder_seq_lens)
......@@ -1521,6 +1538,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
updated_params.add(name)
return updated_params
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_model")
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
for mask in sparse_mask:
......
......@@ -17,6 +17,7 @@
# limitations under the License.
import math
from collections.abc import Iterable, Mapping
from functools import cached_property
from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
......@@ -24,7 +25,6 @@ import torch
from torch import nn
from transformers import BatchFeature, Llama4Config, Llama4VisionConfig
from transformers.image_utils import SizeDict
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.llama4 import Llama4Processor
from transformers.models.llama4.image_processing_llama4_fast import (
find_supported_resolutions, get_best_fit)
......@@ -33,33 +33,30 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import _initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
class Llama4ImagePatchInputs(TypedDict):
......@@ -76,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios: Union[torch.Tensor, list[torch.Tensor]]
"""
A list of aspect ratios corresponding to the number of tiles
......@@ -345,7 +338,7 @@ class Llama4VisionEncoder(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
) -> BaseModelOutput:
) -> torch.Tensor:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
......@@ -361,7 +354,7 @@ class Llama4VisionEncoder(nn.Module):
layer_outputs = encoder_layer(hidden_states)
hidden_states = layer_outputs[0]
return BaseModelOutput(last_hidden_state=hidden_states, )
return hidden_states
class Llama4UnfoldConvolution(nn.Module):
......@@ -433,7 +426,7 @@ class Llama4VisionModel(nn.Module):
def forward(
self,
images_flattened: torch.Tensor,
) -> BaseModelOutput:
) -> torch.Tensor:
# Patch embedding
hidden_state = self.patch_embedding(images_flattened)
num_tiles, num_patches, hidden_dim = hidden_state.shape
......@@ -458,8 +451,7 @@ class Llama4VisionModel(nn.Module):
hidden_state = hidden_state.view(num_tiles, -1, hidden_dim)
# Apply encoder
output = self.model(hidden_state)
hidden_state = output.last_hidden_state
hidden_state = self.model(hidden_state)
hidden_state = self.layernorm_post(hidden_state)
# Remove CLS token output
......@@ -468,10 +460,7 @@ class Llama4VisionModel(nn.Module):
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state = self.vision_adapter(hidden_state)
return BaseModelOutput(
last_hidden_state=hidden_state,
attentions=None,
)
return hidden_state
class Mllama4ProcessingInfo(BaseProcessingInfo):
......@@ -488,7 +477,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 10}
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
return {"image": None}
@staticmethod
def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int:
......@@ -507,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk = self.get_patch_per_chunk(vision_config) + 1
mm_max_tokens = (self.get_max_num_tiles() + 1) * token_per_chunk + 2
return {"image": mm_max_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
......@@ -581,33 +561,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for (r_h, r_w) in aspect_ratios
]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id = tokenizer.get_vocab()[processor.img_patch_token]
num_patches_per_chunk = self.info.get_patch_per_chunk(
vision_config)
expanded_image_tokens_list = [
processor._prompt_split_image(aspect_ratio,
num_patches_per_chunk)
for aspect_ratio in aspect_ratios
]
expanded_image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in expanded_image_tokens_list
]
embed_is_patch = [
torch.tensor(tokens) == patch_id
for tokens in expanded_image_token_ids
]
processed_outputs["aspect_ratios"] = aspect_ratios
processed_outputs["patches_per_image"] = torch.tensor(
patches_per_image)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
......@@ -622,7 +578,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
"image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"),
aspect_ratios=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
......@@ -642,12 +597,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
num_patches_per_chunk = self.info.get_patch_per_chunk(vision_config)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token
img_patch_token = hf_processor.img_patch_token
def get_replacement(item_idx: int):
aspect_ratio = out_mm_kwargs["aspect_ratios"][item_idx]
return hf_processor._prompt_split_image(
repl = hf_processor._prompt_split_image(
aspect_ratio=aspect_ratio,
num_patches_per_chunk=num_patches_per_chunk)
num_patches_per_chunk=num_patches_per_chunk,
)
return PromptUpdateDetails.select_text(repl, img_patch_token)
return [
PromptReplacement(
......@@ -660,36 +620,39 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.fake_image_token
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
(target_width,
target_height) = self.info.get_image_size_with_most_features()
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
image_token = self.info.get_hf_processor().fake_image_token
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor(
Mllama4MultiModalProcessor,
info=Mllama4ProcessingInfo,
dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
......@@ -710,13 +673,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
architectures=["Llama4ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"))
self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
self.language_model = _initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]:
......@@ -730,11 +702,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
aspect_ratios = kwargs.pop("aspect_ratios", None)
if not isinstance(aspect_ratios, (torch.Tensor, list)):
raise ValueError("Incorrect type of aspect_ratios. "
......@@ -744,7 +711,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
type="pixel_values",
flat_data=flat_pixel_values,
patches_per_image=patches_per_image,
embed_is_patch=embed_is_patch,
aspect_ratios=aspect_ratios,
)
......@@ -752,8 +718,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
flat_data = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"].tolist()
vision_embeddings_flat = self.vision_model(flat_data).last_hidden_state
return vision_embeddings_flat.split(patches_per_image, dim=0)
vision_embeddings_flat = self.vision_model(flat_data)
vision_embeddings_flat = self.multi_modal_projector(
vision_embeddings_flat)
return [
img.flatten(0, 1)
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
......@@ -761,20 +737,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
if image_input is None:
return None
# num_images x [num_chunks, num_patches, hidden_dim]
image_features = self._process_image_input(image_input)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat = [img.flatten(0, 1) for img in image_features]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat = [
is_patch.flatten(0, 1)
for is_patch in image_input["embed_is_patch"]
]
return scatter_patch_features(
image_features_flat,
embed_is_patch_flat,
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
......@@ -784,11 +747,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
multimodal_embeddings = torch.cat(multimodal_embeddings)
mm_embeddings = self.multi_modal_projector(multimodal_embeddings)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, select_patch_features(mm_embeddings),
self.config.image_token_index)
input_ids,
inputs_embeds,
multimodal_embeddings,
self.config.image_token_index,
)
return inputs_embeds
......@@ -800,9 +764,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
if "pixel_values" in kwargs:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner,
# this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
......@@ -857,9 +824,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
# language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine.
language_model_prefix = "language_model.model."
language_model_weights, other_weights = self.separate_weights(
weights, prefix=language_model_prefix)
weights, prefix="language_model.model.")
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(
language_model_weights)
......
......@@ -41,13 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
......@@ -56,7 +58,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
......@@ -84,14 +85,6 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
......@@ -1146,30 +1139,6 @@ class MolmoProcessorWrapper:
if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0
input_is_embed = torch.isin(
input_ids,
torch.tensor([
self.image_patch_id,
self.im_col_id,
self.im_start_id,
self.im_end_id,
]),
)
embed_ids = input_ids[input_is_embed]
embed_is_patch = embed_ids == self.image_patch_id
assert embed_is_patch.sum() == feat_is_patch.sum()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start = torch.nonzero(embed_ids == self.im_start_id)[::2, 0]
embed_end = torch.nonzero(embed_ids == self.im_end_id)[1::2, 0]
assert len(embed_start) == len(embed_end) == len(images)
embed_is_patch = [
embed_is_patch[start:end + 1]
for start, end in zip(embed_start, embed_end)
]
tilings = [
self.select_tiling(
image_width=image.size[0],
......@@ -1181,7 +1150,6 @@ class MolmoProcessorWrapper:
assert num_crops.sum() == len(feat_is_patch)
outputs["feat_is_patch"] = feat_is_patch
outputs["embed_is_patch"] = embed_is_patch
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
......@@ -1197,13 +1165,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
......@@ -1220,26 +1181,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
)
pooling_size = processor.pooling_size
base_image_input_size = processor.base_image_input_size
base_image_input_d = processor.image_patch_size
crop_patches = base_image_input_size[0] // base_image_input_d
per_row = ncols // pooling_size + 1
joint = per_row * (nrows // pooling_size) + 2
image_token_length = (crop_patches + pooling_size - 1) // pooling_size
resize = (image_token_length + 1) * image_token_length + 2
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
return resize + joint
extra = image_token_length_w * image_token_length_h
joint = ((ncols + 1) // pooling_size) * ((nrows + 1) // pooling_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return extra + joint
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
......@@ -1269,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo):
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
......@@ -1328,7 +1274,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
......@@ -1368,8 +1313,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint
return image_tokens
return PromptUpdateDetails.select_token_id(
extra_joint + joint,
embed_token_id=img_patch_id,
)
return [
PromptInsertion(
......@@ -1475,11 +1422,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")
embed_is_patch = kwargs.pop("embed_is_patch", None)
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
num_crops = kwargs.pop("num_crops", None)
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
......@@ -1491,14 +1433,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()
embed_is_patch = flatten_bn(embed_is_patch)
num_crops = flatten_bn(num_crops, concat=True)
return MolmoImageInputs(
images=images,
image_masks=image_masks,
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)
......@@ -1531,18 +1471,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
)
]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
......@@ -1556,7 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.img_patch_id,
)
return inputs_embeds
......
......@@ -15,12 +15,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
......@@ -57,7 +56,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find "<tile" as a subsequence of "<Image><tile"
repl = "<Image>" + features + "</Image>"
return PromptUpdateDetails(full=repl, features=repl)
return PromptUpdateDetails.select_text(repl, IMG_PAD)
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
......@@ -84,57 +83,32 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**kwargs,
)
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
tokenizer = hf_processor.tokenizer
max_num_patches = hf_processor.max_dynamic_patch
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers = [
f"<tile_{i+1}>" for i in range(max_num_patches)
]
if hf_processor.use_thumbnail and max_num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
# so we include <tile_1> in the start_str
start_str = "<Image>" + tile_pos_identifiers.pop(0)
end_str = "</Image>"
start_token_len = len(tokenizer.encode(start_str))
end_token_len = len(tokenizer.encode(end_str))
tile_token_len = sum(
len(tokenizer.encode(identifier))
for identifier in tile_pos_identifiers)
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
return super().get_max_image_tokens() + non_image_tokens_num
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
return "<image>\n" * num_images
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
def get_dummy_processor_inputs(
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
prompt_text="<image>\n" * num_images,
mm_data=mm_data,
)
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
......@@ -177,10 +151,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl = hf_processor.get_image_repl(feature_size, num_patches)
return PromptUpdateDetails(
full=repl.full + "\n",
features=repl.features + "\n",
)
return PromptUpdateDetails.select_text(repl.full + "\n", IMG_PAD)
# See note in dummy data regarding why we have the extra newline
return [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment