Unverified Commit afc35ccc authored by Chenxi Li, committed by GitHub


Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (#11261)
parent a57f0e3d
@@ -418,10 +418,6 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
......
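The effect of these two hunks is that LoRAManager no longer hard-codes vision-module names; it defers to an optional per-model should_apply_lora hook and otherwise falls back to matching leaf module names against target_modules. A minimal sketch of that selection logic, assuming a simplified manager (TinyLoRAManager and modules_to_convert are illustrative names, not sglang APIs):

from typing import List


class TinyLoRAManager:
    """Illustrative stand-in for the real LoRAManager."""

    def __init__(self, base_model, target_modules: List[str]):
        self.base_model = base_model
        self.target_modules = target_modules

    def modules_to_convert(self, module_names: List[str]) -> List[str]:
        selected = []
        for module_name in module_names:
            # Defer to the model's own policy when it defines one.
            if hasattr(
                self.base_model, "should_apply_lora"
            ) and not self.base_model.should_apply_lora(module_name):
                continue
            # Otherwise match on the leaf module name, as in init_lora_modules.
            if module_name.split(".")[-1] in self.target_modules:
                selected.append(module_name)
        return selected

A model that does not define should_apply_lora keeps the previous behavior (every matching target module is converted), while a VLM that does define it can exclude its vision components without the manager knowing anything about the model's internal structure.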
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
 
     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config
 
+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return hs
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()
......
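To make the new lora_pattern concrete, the following self-contained check shows which module names it accepts. The regex is copied from the hunk above; the vision-side module names are illustrative examples, not names taken from this diff:

import re

# Same pattern as Gemma3ForConditionalGeneration.lora_pattern above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

# Language-model projections match, so LoRA weights are applied to them.
assert lora_pattern.match("language_model.model.layers.0.self_attn.qkv_proj")
assert lora_pattern.match("language_model.model.layers.17.mlp.gate_up_proj")

# Vision tower and projector modules do not match, so they are skipped.
assert not lora_pattern.match("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")
assert not lora_pattern.match("multi_modal_projector.mm_input_projection_weight")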
@@ -2,6 +2,7 @@ import json as json_lib
 import logging
 import math
 import os
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: Llama4Config,
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         return projected_vision_flat
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def forward(
         self,
         input_ids: torch.Tensor,
......
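Both models follow the same recipe, so a hypothetical new VLM would only need a class-level lora_pattern and a should_apply_lora method to opt in. A sketch under that assumption (the class name is made up, and the regex would depend on the model's actual module layout):

import re

import torch.nn as nn


class MyNewVLMForConditionalGeneration(nn.Module):  # hypothetical model, for illustration only
    supports_lora = True

    # Match only language-model layers; vision encoder and projector
    # modules do not match and are therefore skipped by the LoRA manager.
    lora_pattern = re.compile(
        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
    )

    def should_apply_lora(self, module_name: str) -> bool:
        """Apply LoRA only to modules whose names match lora_pattern."""
        return bool(self.lora_pattern.match(module_name))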