Unverified Commit 373080ea authored by gongwei-130's avatar gongwei-130 Committed by GitHub
Browse files

skip vision_model for lora (#10530)

parent 7f028b07
...@@ -415,6 +415,10 @@ class LoRAManager: ...@@ -415,6 +415,10 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module) replace_submodule(self.base_model, module_name, lora_module)
return lora_module return lora_module
def should_skip_lora_for_vision_model(self, module_name):
    """Return True if LoRA adaptation should be skipped for this module.

    Vision-tower submodules (names containing ``"vision_model.model"``)
    are excluded from LoRA injection; only the language model is adapted.

    Args:
        module_name: Dotted module path within the base model
            (e.g. ``"vision_model.model.layers.0.self_attn"``).

    Returns:
        bool: True when *module_name* belongs to the vision model.
    """
    # TODO: support different vision models
    # Idiomatic substring test; equivalent to `.find(...) != -1`.
    return "vision_model.model" in module_name
def init_lora_modules(self): def init_lora_modules(self):
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
...@@ -432,6 +436,10 @@ class LoRAManager: ...@@ -432,6 +436,10 @@ class LoRAManager:
) and not self.base_model.should_apply_lora(module_name): ) and not self.base_model.should_apply_lora(module_name):
continue continue
# Skip vision model
if self.should_skip_lora_for_vision_model(module_name):
continue
# The module should be converted if it is included in target_names # The module should be converted if it is included in target_names
if module_name.split(".")[-1] in self.target_modules: if module_name.split(".")[-1] in self.target_modules:
layer_id = get_layer_id(module_name) layer_id = get_layer_id(module_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment