Unverified commit 2614adf9, authored by Antonin Vidon, committed by GitHub

[Fix] Skip visual layers when applying LoRA to Qwen2VL modules (#11519)

parent fdd7c69d
@@ -28,7 +28,6 @@ from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
@@ -514,6 +513,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens

+    def should_apply_lora(self, module_name: str) -> bool:
+        # skip visual tower
+        return not module_name.startswith("visual")
+
     def forward(
         self,
         input_ids: torch.Tensor,
......
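For context, the new hook lets the LoRA machinery ask, module by module, whether adapter weights should be attached, so the ViT tower under "visual" is excluded and only the language-model projections are wrapped. Below is a minimal sketch of that filtering idea, not vLLM's actual LoRA manager; collect_lora_targets and the toy module tree are hypothetical illustrations.

import torch.nn as nn

def collect_lora_targets(model: nn.Module, should_apply) -> list:
    """Names of linear submodules that pass the LoRA filter.

    A real LoRA manager would wrap each selected module with adapter
    weights; here we only collect names to show the filtering effect.
    """
    return [
        name for name, mod in model.named_modules()
        if isinstance(mod, nn.Linear) and should_apply(name)
    ]

# Hypothetical stand-in for the multimodal module tree.
model = nn.ModuleDict({
    "visual": nn.ModuleDict({"attn": nn.Linear(8, 8)}),   # vision tower
    "model": nn.ModuleDict({"q_proj": nn.Linear(8, 8)}),  # language model
})

# Mirrors the new hook: skip anything under the visual tower.
print(collect_lora_targets(model, lambda n: not n.startswith("visual")))
# -> ['model.q_proj']

With Qwen2VL's real module naming, a name like "visual.blocks.0.attn.qkv" would be skipped while "model.layers.0.self_attn.q_proj" would still receive LoRA weights, which is exactly what the new should_apply_lora encodes.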