Unverified Commit a2b053dc authored by Navanit Dubey's avatar Navanit Dubey Committed by GitHub
Browse files

feat(model): Add BitsAndBytes quantization support for Qwen3-Omni-MoE (#29896)


Signed-off-by: default avatarnavanit-git <navanitdubey@gmail.com>
parent 1d93f116
...@@ -62,6 +62,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -62,6 +62,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
...@@ -1137,6 +1138,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1137,6 +1138,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
} }
) )
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"): if modality.startswith("image"):
...@@ -1763,3 +1776,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1763,3 +1776,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
mrope_position_delta = llm_positions.max() + 1 - seq_len mrope_position_delta = llm_positions.max() + 1 - seq_len
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="visual.merger",
tower_model=["visual.", "audio_tower."],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment