Unverified Commit 6f143533 authored by Sourab Mangrulkar, committed by GitHub

Speed up the peft lora unload (#5741)



* Update peft_utils.py

* fix bug

* make the util backwards compatible.
Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix import issue

* refactor the backward compatibility condition

* rename the conditional variable

* address comments
Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* address comment

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent c6f90dae
@@ -23,13 +23,36 @@ from packaging import version
 from .import_utils import is_peft_available, is_torch_available
 
+if is_torch_available():
+    import torch
+
+
 def recurse_remove_peft_layers(model):
-    if is_torch_available():
-        import torch
-
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
-
-    for name, module in model.named_children():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
+
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
+
+        for name, module in model.named_children():
...
@@ -71,7 +94,6 @@ def recurse_remove_peft_layers(model):
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     return model
...
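For context (not part of the commit), a minimal usage sketch of where this faster unload path is exercised, assuming a diffusers build with PEFT-backed LoRA support and PEFT >= 0.7 (where tuner layers expose `base_layer`); the LoRA repository id below is purely illustrative:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Hypothetical LoRA repository id, used only to illustrate the load/unload cycle.
pipe.load_lora_weights("some-user/some-lora")

# unload_lora_weights() ultimately reaches recurse_remove_peft_layers(); with a
# PEFT version that exposes `base_layer`, each LoraLayer is simply swapped back
# to its original base module instead of being rebuilt layer by layer, which is
# where the speed-up comes from.
pipe.unload_lora_weights()
```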