Unverified Commit 067eab1b authored by Thanh Le's avatar Thanh Le Committed by GitHub
Browse files

Faster set_adapters (#10777)



* Update peft_utils.py

* Update peft_utils.py

* Update peft_utils.py

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 57ac6738
...@@ -257,26 +257,18 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): ...@@ -257,26 +257,18 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
return block_weight return block_weight
# iterate over each adapter, make it active and set the corresponding scaling weight for module_name, module in model.named_modules():
for adapter_name, weight in zip(adapter_names, weights):
for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.set_scale(adapter_name, get_module_weight(weight, module_name))
# set multiple active adapters
for module in model.modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
# For backward compatbility with previous PEFT versions # For backward compatibility with previous PEFT versions, set multiple active adapters
if hasattr(module, "set_adapter"): if hasattr(module, "set_adapter"):
module.set_adapter(adapter_names) module.set_adapter(adapter_names)
else: else:
module.active_adapter = adapter_names module.active_adapter = adapter_names
# Set the scaling weight for each adapter for this module
for adapter_name, weight in zip(adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))
def check_peft_version(min_version: str) -> None: def check_peft_version(min_version: str) -> None:
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment