Unverified commit 989ecd20, authored by Jee Jee Li, committed by GitHub
Browse files

[Misc] Gemma3ForConditionalGeneration supports LoRA (#14797)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 54cc46f3
...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig ...@@ -12,6 +12,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
...@@ -23,7 +24,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -23,7 +24,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -371,8 +373,8 @@ class Gemma3MultiModalProjector(nn.Module): ...@@ -371,8 +373,8 @@ class Gemma3MultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo, info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder) dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsPP): SupportsLoRA):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -614,3 +616,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -614,3 +616,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-name prefixes of this multimodal model.

    The result is built by ``MultiModelKeys.from_string_field`` from three
    prefixes: the language model, the connector (multimodal projector) and
    the tower model (vision tower).
    """
    # Keyword name expected by MultiModelKeys -> submodule prefix string.
    module_prefixes = {
        "language_model": "language_model",
        "connector": "multi_modal_projector",
        "tower_model": "vision_tower",
    }
    return MultiModelKeys.from_string_field(**module_prefixes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment