Unverified commit afc35ccc authored by Chenxi Li, committed by GitHub


Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (#11261)
parent a57f0e3d
@@ -418,10 +418,6 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +435,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
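The two removals above are safe because LoRAManager already defers to an optional should_apply_lora hook on the base model, visible in the context lines of the second hunk. A minimal sketch of that filtering convention, with a hypothetical helper name and the rest of init_lora_modules omitted:

```python
# Minimal sketch of the skip convention, not the actual LoRAManager implementation.
# A model may expose should_apply_lora(module_name); any module it rejects
# (e.g. vision tower or projector submodules) is never wrapped with a LoRA layer.
def iter_lora_target_modules(base_model, target_modules):
    for module_name, module in base_model.named_modules():
        should_apply = getattr(base_model, "should_apply_lora", None)
        if should_apply is not None and not should_apply(module_name):
            continue  # the model opted this module out of LoRA
        # Same suffix check as init_lora_modules: only convert configured targets
        if module_name.split(".")[-1] in target_modules:
            yield module_name, module
```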
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
 
 import logging
+import re
 from functools import lru_cache
 from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     embedding_modules = {}
     embedding_padding_modules = []
     supports_lora = True
+    # Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
 
     def __init__(
         self,
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         self.config = config
         self.quant_config = quant_config
 
+        # For LoRA compatibility: expose text_config attributes at top level
+        # This allows LoRA code to work without special multimodal handling
+        if not hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = config.text_config.num_hidden_layers
+        if not hasattr(config, "hidden_size"):
+            config.hidden_size = config.text_config.hidden_size
+
         self.vision_tower = SiglipVisionModel(
             config=config.vision_config,
             quant_config=quant_config,
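The hasattr guards above exist because generic LoRA code usually reads model dimensions from the top-level config, while a multimodal config such as Gemma3's keeps them under text_config. A rough illustration of the kind of consumer this unblocks (the helper below is hypothetical, not code from this repository):

```python
# Hypothetical consumer: flat-config LoRA bookkeeping.
# Without the aliasing in __init__, config.num_hidden_layers would raise
# AttributeError on a multimodal config that nests it in config.text_config.
def build_per_layer_lora_tables(config):
    num_layers = config.num_hidden_layers  # resolves after the aliasing
    hidden_size = config.hidden_size
    return [dict() for _ in range(num_layers)], hidden_size
```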
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         return hs
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision tower and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def tie_weights(self):
         return self.language_model.tie_weights()
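To see which modules the new lora_pattern admits, here is a small self-contained check against representative module names (the names are illustrative, not read from a real checkpoint):

```python
import re

lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

# Language-model attention/MLP projections match, so they receive LoRA.
assert lora_pattern.match("language_model.model.layers.0.self_attn.qkv_proj")
assert lora_pattern.match("language_model.model.layers.11.mlp.gate_up_proj")

# Vision tower and projector modules never match, so they are skipped.
assert not lora_pattern.match("vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")
assert not lora_pattern.match("multi_modal_projector.linear")
```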
@@ -2,6 +2,7 @@
 import json as json_lib
 import logging
 import math
 import os
+import re
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
+    # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: Llama4Config,
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         return projected_vision_flat
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        """Skip vision model and multi_modal_projector for LoRA."""
+        return bool(self.lora_pattern.match(module_name))
+
     def forward(
         self,
         input_ids: torch.Tensor,
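Gemma3 and Llama4 now share the same three pieces, so another VLM that wants LoRA support would follow the same recipe; a hedged outline (the class name is an illustrative placeholder, only the convention itself comes from this change):

```python
import re

class SomeOtherVLMForConditionalGeneration:  # illustrative placeholder class
    supports_lora = True
    # Restrict LoRA to language-model layers; vision encoder and projector names never match.
    lora_pattern = re.compile(
        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\."
        r"(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
    )

    def should_apply_lora(self, module_name: str) -> bool:
        return bool(self.lora_pattern.match(module_name))
```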