Unverified Commit f021b979 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[V1] Support Mistral3 in V1 (#15950)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 1cab43c2
...@@ -888,7 +888,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -888,7 +888,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
* *
* ✅︎ * ✅︎
* * ✅︎
- * `MllamaForConditionalGeneration` - * `MllamaForConditionalGeneration`
* Llama 3.2 * Llama 3.2
* T + I<sup>+</sup> * T + I<sup>+</sup>
......
...@@ -31,12 +31,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -31,12 +31,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
SupportsV0Only)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info, select_patch_features from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
class Mistral3ImagePixelInputs(TypedDict): class Mistral3ImagePixelInputs(TypedDict):
...@@ -425,7 +425,7 @@ def init_vision_tower_for_llava( ...@@ -425,7 +425,7 @@ def init_vision_tower_for_llava(
info=_build_mistral3_info, info=_build_mistral3_info,
dummy_inputs=Mistral3DummyInputsBuilder) dummy_inputs=Mistral3DummyInputsBuilder)
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsV0Only): SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
...@@ -518,7 +518,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -518,7 +518,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Mistral3ImagePixelInputs( return Mistral3ImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch, embed_is_patch=flatten_bn(embed_is_patch),
) )
def _process_image_input( def _process_image_input(
...@@ -557,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -557,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return scatter_patch_features(
vision_embeddings,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment