Unverified Commit a44c4f1d authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Support LoRA for Mistral3 (#17428)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 88fcf00d
......@@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generativ
* Mistral3
* T + I<sup>+</sup>
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
*
* ✅︎
* ✅︎
* ✅︎
- * `MllamaForConditionalGeneration`
......
......@@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
......@@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
......@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
_build_mistral3_processor,
info=_build_mistral3_info,
dummy_inputs=Mistral3DummyInputsBuilder)
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
SupportsMultiModal, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
......@@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="multi_modal_projector",
tower_model="vision_tower")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment