Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a44c4f1d
Unverified
Commit
a44c4f1d
authored
Apr 29, 2025
by
Michael Goin
Committed by
GitHub
Apr 29, 2025
Browse files
Support LoRA for Mistral3 (#17428)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
88fcf00d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
4 deletions
+15
-4
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+1
-1
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+14
-3
No files found.
docs/source/models/supported_models.md
View file @
a44c4f1d
...
@@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generativ
*
Mistral3
*
Mistral3
*
T + I
<sup>
+
</sup>
*
T + I
<sup>
+
</sup>
*
`mistralai/Mistral-Small-3.1-24B-Instruct-2503`
, etc.
*
`mistralai/Mistral-Small-3.1-24B-Instruct-2503`
, etc.
*
*
✅︎
*
✅︎
*
✅︎
*
✅︎
*
✅︎
-
*
`MllamaForConditionalGeneration`
-
*
`MllamaForConditionalGeneration`
...
...
vllm/model_executor/models/mistral3.py
View file @
a44c4f1d
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
...
@@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -31,7 +32,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
)
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
...
@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
...
@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
_build_mistral3_processor
,
_build_mistral3_processor
,
info
=
_build_mistral3_info
,
info
=
_build_mistral3_info
,
dummy_inputs
=
Mistral3DummyInputsBuilder
)
dummy_inputs
=
Mistral3DummyInputsBuilder
)
class
Mistral3ForConditionalGeneration
(
nn
.
Module
,
Supports
MultiModal
,
class
Mistral3ForConditionalGeneration
(
nn
.
Module
,
Supports
LoRA
,
SupportsPP
):
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
...
@@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -594,3 +596,12 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-name prefixes of this multimodal model.

    Builds the result with ``MultiModelKeys.from_string_field``, passing
    the prefix strings for the language model, the connector and the
    tower model as keyword arguments.
    """
    # Keyword name -> module prefix string handed to the factory.
    prefix_by_role = {
        "language_model": "language_model",
        "connector": "multi_modal_projector",
        "tower_model": "vision_tower",
    }
    return MultiModelKeys.from_string_field(**prefix_by_role)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment