Commit f021b979 (unverified), authored by Michael Goin, committed by GitHub
Browse files

[V1] Support Mistral3 in V1 (#15950)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 1cab43c2
......@@ -888,7 +888,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
*
* ✅︎
- *
+ * ✅︎
- * `MllamaForConditionalGeneration`
* Llama 3.2
* T + I<sup>+</sup>
......
......@@ -31,12 +31,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
-SupportsV0Only)
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info, select_patch_features
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+select_patch_features)
class Mistral3ImagePixelInputs(TypedDict):
......@@ -425,7 +425,7 @@ def init_vision_tower_for_llava(
info=_build_mistral3_info,
dummy_inputs=Mistral3DummyInputsBuilder)
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
-SupportsPP, SupportsV0Only):
+SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
......@@ -518,7 +518,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Mistral3ImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
-embed_is_patch=embed_is_patch,
+embed_is_patch=flatten_bn(embed_is_patch),
)
def _process_image_input(
......@@ -557,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input)
-return vision_embeddings
+return scatter_patch_features(
+vision_embeddings,
+image_input["embed_is_patch"],
+)
def get_input_embeddings(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment