Unverified Commit 42f25d60 authored by Steven Munn's avatar Steven Munn Committed by GitHub
Browse files

Skip PEFT LoRA Scaling if the scale is 1.0 (#7576)



* Skip scaling if scale is identity

* move check for weight one to scale and unscale lora

* fix code style/quality

* Empty-Commit

---------
Co-authored-by: default avatarSteven Munn <stevenjmunn@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarSteven Munn <5297082+stevenjlm@users.noreply.github.com>
parent 33c5d125
...@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model): ...@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model):
module_replaced = False module_replaced = False
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to( new_module = torch.nn.Linear(
module.weight.device module.in_features,
) module.out_features,
bias=module.bias is not None,
).to(module.weight.device)
new_module.weight = module.weight new_module.weight = module.weight
if module.bias is not None: if module.bias is not None:
new_module.bias = module.bias new_module.bias = module.bias
...@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight): ...@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight):
""" """
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
if weight == 1.0:
return
for module in model.modules(): for module in model.modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
module.scale_layer(weight) module.scale_layer(weight)
...@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None): ...@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
""" """
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
if weight == 1.0:
return
for module in model.modules(): for module in model.modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
if weight is not None and weight != 0: if weight is not None and weight != 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment