Unverified Commit 3ca18d6d authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Fix PEFT multi adapters support (#26407)



* fix PEFT multi adapters support

* refactor a bit

* save pretrained + BC + added tests

* Update src/transformers/integrations/peft.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* add more tests

* add suggestion

* final changes

* adapt a bit

* fixup

* Update src/transformers/integrations/peft.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* adapt from suggestions

---------
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 946bac79
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import ( from ..utils import (
check_peft_version, check_peft_version,
...@@ -245,7 +245,7 @@ class PeftAdapterMixin: ...@@ -245,7 +245,7 @@ class PeftAdapterMixin:
self.set_adapter(adapter_name) self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: str) -> None: def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
""" """
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft official documentation: https://huggingface.co/docs/peft
...@@ -253,12 +253,19 @@ class PeftAdapterMixin: ...@@ -253,12 +253,19 @@ class PeftAdapterMixin:
Sets a specific adapter by forcing the model to use that adapter and disable the other adapters. Sets a specific adapter by forcing the model to use that adapter and disable the other adapters.
Args: Args:
adapter_name (`str`): adapter_name (`Union[List[str], str]`):
The name of the adapter to set. The name of the adapter to set. Can also be a list of strings to set multiple adapters.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded: if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.") raise ValueError("No adapter loaded. Please load an adapter first.")
elif isinstance(adapter_name, list):
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
elif adapter_name not in self.peft_config: elif adapter_name not in self.peft_config:
raise ValueError( raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
...@@ -270,7 +277,11 @@ class PeftAdapterMixin: ...@@ -270,7 +277,11 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name # For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True _adapters_has_been_set = True
if not _adapters_has_been_set: if not _adapters_has_been_set:
...@@ -294,7 +305,11 @@ class PeftAdapterMixin: ...@@ -294,7 +305,11 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
module.disable_adapters = True # The recent version of PEFT needs to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True
def enable_adapters(self) -> None: def enable_adapters(self) -> None:
""" """
...@@ -312,14 +327,22 @@ class PeftAdapterMixin: ...@@ -312,14 +327,22 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
module.disable_adapters = False # The recent version of PEFT needs to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False
def active_adapter(self) -> str: def active_adapters(self) -> List[str]:
""" """
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft official documentation: https://huggingface.co/docs/peft
Gets the current active adapter of the model. Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions (that do not support multi-adapter inference), `module.active_adapter` will return
a single string.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
...@@ -333,7 +356,21 @@ class PeftAdapterMixin: ...@@ -333,7 +356,21 @@ class PeftAdapterMixin:
for _, module in self.named_modules(): for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
return module.active_adapter active_adapters = module.active_adapter
break
# For previous PEFT versions
if isinstance(active_adapters, str):
active_adapters = [active_adapters]
return active_adapters
def active_adapter(self) -> str:
logger.warning(
"The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
)
return self.active_adapters()[0]
def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
""" """
......
...@@ -2006,7 +2006,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2006,7 +2006,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
peft_state_dict[f"base_model.model.{key}"] = value peft_state_dict[f"base_model.model.{key}"] = value
state_dict = peft_state_dict state_dict = peft_state_dict
current_peft_config = self.peft_config[self.active_adapter()] active_adapter = self.active_adapters()
if len(active_adapter) > 1:
raise ValueError(
"Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
)
active_adapter = active_adapter[0]
current_peft_config = self.peft_config[active_adapter]
current_peft_config.save_pretrained(save_directory) current_peft_config.save_pretrained(save_directory)
# Save the model # Save the model
......
...@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): ...@@ -265,9 +265,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
_ = model.generate(input_ids=dummy_input) _ = model.generate(input_ids=dummy_input)
model.set_adapter("default") model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default") self.assertTrue(model.active_adapter() == "default")
model.set_adapter("adapter-2") model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2") self.assertTrue(model.active_adapter() == "adapter-2")
# Logits comparison # Logits comparison
...@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): ...@@ -276,6 +278,23 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
) )
self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)) self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")
logits_adapter_mixed = model(dummy_input)
self.assertFalse(
torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
self.assertFalse(
torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
)
# multi active adapter saving not supported
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@require_torch_gpu @require_torch_gpu
def test_peft_from_pretrained_kwargs(self): def test_peft_from_pretrained_kwargs(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment