Commit 0a55d9f7 (unverified), authored by Patrick von Platen, committed by GitHub
Browse files

[PEFT] Allow PEFT model dict to be loaded (#25721)



* Allow PEFT model dict to be loaded

* make style

* make style

* Apply suggestions from code review

* address comments

* fixup

* final change

* added tests

* fix test

* better logic for handling if adapter has been loaded

* Update tests/peft_integration/test_peft_integration.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 8b134714
...@@ -12,13 +12,14 @@ ...@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from ..utils import ( from ..utils import (
check_peft_version, check_peft_version,
find_adapter_config_file, find_adapter_config_file,
is_accelerate_available, is_accelerate_available,
is_peft_available, is_peft_available,
is_torch_available,
logging, logging,
) )
...@@ -30,6 +31,11 @@ if is_accelerate_available(): ...@@ -30,6 +31,11 @@ if is_accelerate_available():
# Minimum PEFT version supported for the integration # Minimum PEFT version supported for the integration
MIN_PEFT_VERSION = "0.5.0" MIN_PEFT_VERSION = "0.5.0"
if TYPE_CHECKING:
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -61,7 +67,7 @@ class PeftAdapterMixin: ...@@ -61,7 +67,7 @@ class PeftAdapterMixin:
def load_adapter( def load_adapter(
self, self,
peft_model_id: str, peft_model_id: Optional[str] = None,
adapter_name: Optional[str] = None, adapter_name: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
token: Optional[str] = None, token: Optional[str] = None,
...@@ -69,6 +75,8 @@ class PeftAdapterMixin: ...@@ -69,6 +75,8 @@ class PeftAdapterMixin:
max_memory: Optional[str] = None, max_memory: Optional[str] = None,
offload_folder: Optional[str] = None, offload_folder: Optional[str] = None,
offload_index: Optional[int] = None, offload_index: Optional[int] = None,
peft_config: Dict[str, Any] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
) -> None: ) -> None:
""" """
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
...@@ -77,7 +85,7 @@ class PeftAdapterMixin: ...@@ -77,7 +85,7 @@ class PeftAdapterMixin:
Requires peft as a backend to load the adapter weights. Requires peft as a backend to load the adapter weights.
Args: Args:
peft_model_id (`str`): peft_model_id (`str`, *optional*):
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
and adapter weights. and adapter weights.
adapter_name (`str`, *optional*): adapter_name (`str`, *optional*):
...@@ -114,6 +122,12 @@ class PeftAdapterMixin: ...@@ -114,6 +122,12 @@ class PeftAdapterMixin:
If the `device_map` contains any value `"disk"`, the folder where we will offload weights. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_index (`int`, `optional`): offload_index (`int`, `optional`):
`offload_index` argument to be passed to `accelerate.dispatch_model` method. `offload_index` argument to be passed to `accelerate.dispatch_model` method.
peft_config (`Dict[str, Any]`, *optional*):
The configuration of the adapter to add. Supported adapters are the methods other than prefix tuning
and adaption prompts. Use this argument when passing a PEFT state dict directly.
adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. Use this argument when passing a PEFT state dict directly
instead of a `peft_model_id`.
""" """
check_peft_version(min_version=MIN_PEFT_VERSION) check_peft_version(min_version=MIN_PEFT_VERSION)
...@@ -122,11 +136,15 @@ class PeftAdapterMixin: ...@@ -122,11 +136,15 @@ class PeftAdapterMixin:
from peft import PeftConfig, inject_adapter_in_model, load_peft_weights from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
from peft.utils import set_peft_model_state_dict from peft.utils import set_peft_model_state_dict
if not self._hf_peft_config_loaded: if self._hf_peft_config_loaded and adapter_name in self.peft_config:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
raise ValueError(
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
)
if peft_config is None:
adapter_config_file = find_adapter_config_file( adapter_config_file = find_adapter_config_file(
peft_model_id, peft_model_id,
revision=revision, revision=revision,
...@@ -139,15 +157,19 @@ class PeftAdapterMixin: ...@@ -139,15 +157,19 @@ class PeftAdapterMixin:
"adapter model." "adapter model."
) )
loaded_peft_config = PeftConfig.from_pretrained( peft_config = PeftConfig.from_pretrained(
peft_model_id, peft_model_id,
revision=revision, revision=revision,
use_auth_token=token, use_auth_token=token,
) )
# Create and add fresh new adapters into the model. # Create and add fresh new adapters into the model.
inject_adapter_in_model(loaded_peft_config, self, adapter_name) inject_adapter_in_model(peft_config, self, adapter_name)
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
if peft_model_id is not None:
adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
......
...@@ -16,6 +16,8 @@ import os ...@@ -16,6 +16,8 @@ import os
import tempfile import tempfile
import unittest import unittest
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, OPTForCausalLM from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available from transformers.utils import is_torch_available
...@@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): ...@@ -300,3 +302,33 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
for model_id in self.peft_test_model_ids: for model_id in self.peft_test_model_ids:
pipe = pipeline("text-generation", model_id) pipe = pipeline("text-generation", model_id)
_ = pipe("Hello") _ = pipe("Hello")
def test_peft_add_adapter_with_state_dict(self):
    """
    Check that `load_adapter` accepts an in-memory `adapter_state_dict` together with a `peft_config` instead of a
    `peft_model_id`.

    For every (base model, PEFT checkpoint) pair and every tested model class, the test:
      1. asserts that calling `load_adapter` with no source at all raises `ValueError`;
      2. downloads the adapter weights from the Hub and loads them through `adapter_state_dict` + `peft_config`;
      3. asserts that a second call passing the state dict without a `peft_config` raises `ValueError`;
      4. checks that LoRA layers were injected and that a dummy generation runs.
    """
    from peft import LoraConfig

    dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

    for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
        for transformers_class in self.transformers_test_model_classes:
            model = transformers_class.from_pretrained(model_id).to(torch_device)

            peft_config = LoraConfig(init_lora_weights=False)

            # Neither a model id nor a (config, state dict) pair is given: nothing to load.
            with self.assertRaises(ValueError):
                model.load_adapter(peft_model_id=None)

            state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

            # NOTE(review): `torch.load` unpickles the file; acceptable here because the checkpoint is a
            # known test fixture, but do not copy this pattern for untrusted downloads.
            dummy_state_dict = torch.load(state_dict_path)

            model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)

            # A single, un-nested call: the original wrapped one `load_adapter` call inside another, so the
            # outer call was never reached and only the inner one was actually under test.
            # NOTE(review): presumably this raises because the adapter name used by the call above is
            # already registered -- confirm against `load_adapter`.
            with self.assertRaises(ValueError):
                model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=None)

            self.assertTrue(self._check_lora_correctly_converted(model))

            # dummy generation
            _ = model.generate(input_ids=dummy_input)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment