"examples/vscode:/vscode.git/clone" did not exist on "13deb95a405bbd1037ad233c692d7fd1de9d31e3"
Unverified commit 38e96324, authored by Younes Belkada, committed by GitHub

[`PEFT`] introducing `adapter_kwargs` for loading adapters from a different Hub location (`subfolder`, `revision`) than the base model (#26270)

* make use of adapter_revision

* v1 adapter kwargs

* fix CI

* fix CI

* fix CI

* fixup

* add BC

* Update src/transformers/integrations/peft.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* change it to error

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* fixup

* change

* Update src/transformers/integrations/peft.py

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 52e2c13d
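In practice, the change lets the adapter weights live at a different `revision` or `subfolder` than the base checkpoint. A minimal usage sketch (the repo id is the internal test checkpoint exercised in the test at the bottom of this diff, shown here for illustration):

```python
from transformers import AutoModelForCausalLM

# The LoRA weights live on the "test" branch of the adapter repo, while the
# base model is resolved from adapter_config.json's base_model_name_or_path.
model = AutoModelForCausalLM.from_pretrained(
    "peft-internal-testing/tiny-opt-lora-revision",
    adapter_kwargs={"revision": "test"},
)
```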
@@ -77,6 +77,7 @@ class PeftAdapterMixin:
         offload_index: Optional[int] = None,
         peft_config: Dict[str, Any] = None,
         adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        adapter_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
@@ -128,10 +129,15 @@ class PeftAdapterMixin:
             adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                 dicts
+            adapter_kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
+                `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
         adapter_name = adapter_name if adapter_name is not None else "default"
+        if adapter_kwargs is None:
+            adapter_kwargs = {}
 
         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict
@@ -144,11 +150,20 @@ class PeftAdapterMixin:
                 "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
             )
 
+        # We keep `revision` in the signature for backward compatibility
+        if revision is not None and "revision" not in adapter_kwargs:
+            adapter_kwargs["revision"] = revision
+        elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
+            logger.error(
+                "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
+                "The one in `adapter_kwargs` will be used."
+            )
+
         if peft_config is None:
             adapter_config_file = find_adapter_config_file(
                 peft_model_id,
-                revision=revision,
                 token=token,
+                **adapter_kwargs,
             )
 
             if adapter_config_file is None:
@@ -159,8 +174,8 @@ class PeftAdapterMixin:
             peft_config = PeftConfig.from_pretrained(
                 peft_model_id,
-                revision=revision,
                 use_auth_token=token,
+                **adapter_kwargs,
             )
 
         # Create and add fresh new adapters into the model.
@@ -170,7 +185,7 @@ class PeftAdapterMixin:
         self._hf_peft_config_loaded = True
 
         if peft_model_id is not None:
-            adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
+            adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs)
 
         # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
         processed_adapter_state_dict = {}
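The backward-compatibility branch above gives `adapter_kwargs["revision"]` priority: a standalone `revision` is copied into `adapter_kwargs` only when no entry exists there, and a conflicting pair logs an error and keeps the `adapter_kwargs` value. A minimal sketch of the two call styles; the adapter repo id and adapter names are illustrative placeholders, not real checkpoints:

```python
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Legacy style, still supported: `revision` is forwarded to the adapter lookup.
base.load_adapter("some-user/opt-125m-lora", revision="v2", adapter_name="legacy")

# New style: all hub kwargs for the adapter travel in `adapter_kwargs`.
# If both are given and disagree, the `adapter_kwargs` entry wins (an error is logged).
base.load_adapter(
    "some-user/opt-125m-lora",
    adapter_kwargs={"revision": "v2", "subfolder": "checkpoints"},
    adapter_name="new",
)
```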
@@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
 
+        # Not relevant for Flax Models
+        _ = kwargs.pop("adapter_kwargs", None)
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         commit_hash = kwargs.pop("_commit_hash", None)
         tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
 
+        # Not relevant for TF models
+        _ = kwargs.pop("adapter_kwargs", None)
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
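These two hunks are pure plumbing: the auto classes (changed later in this diff) always place `adapter_kwargs` into `kwargs`, so the TF and Flax loaders, which have no PEFT support, must consume the key before their usual leftover-kwargs handling sees it. A schematic of the pattern, not the real method:

```python
def from_pretrained_sketch(model_id, **kwargs):
    # Discard the PEFT-only key so it is not treated as a config override or
    # reported as an unexpected kwarg by the framework-specific loader.
    _ = kwargs.pop("adapter_kwargs", None)
    # ... remainder of the TF/Flax loading logic ...
    return kwargs
```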
@@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)
-        _adapter_model_path = kwargs.pop("_adapter_model_path", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             commit_hash = getattr(config, "_commit_hash", None)
 
         if is_peft_available():
+            _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
+
             if _adapter_model_path is None:
                 _adapter_model_path = find_adapter_config_file(
                     pretrained_model_name_or_path,
@@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
                     revision=revision,
                     subfolder=subfolder,
                     _commit_hash=commit_hash,
+                    **adapter_kwargs,
                 )
 
             if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
                 with open(_adapter_model_path, "r", encoding="utf-8") as f:
                     _adapter_model_path = pretrained_model_name_or_path
                     pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+        else:
+            _adapter_model_path = None
 
         # change device_map into a map if we passed an int, a str or a torch.device
         if isinstance(device_map, torch.device):
@@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model.load_adapter(
                 _adapter_model_path,
                 adapter_name=adapter_name,
-                revision=revision,
                 token=token,
+                adapter_kwargs=adapter_kwargs,
             )
 
         if output_loading_info:
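The branch above is what makes `from_pretrained` transparent for adapter repos: when the resolved path points at an `adapter_config.json`, the model id is swapped for the adapter's declared base model, and the original id is kept so the adapter can be attached afterwards. A self-contained sketch of that lookup, with illustrative file contents:

```python
import json

# An adapter repo ships an adapter_config.json such as (abridged):
config = {"base_model_name_or_path": "facebook/opt-125m", "peft_type": "LORA"}
with open("adapter_config.json", "w", encoding="utf-8") as f:
    json.dump(config, f)

# from_pretrained reads the file and redirects to the base checkpoint,
# remembering the adapter location for a later load_adapter call.
with open("adapter_config.json", "r", encoding="utf-8") as f:
    base_model_name_or_path = json.load(f)["base_model_name_or_path"]
print(base_model_name_or_path)  # -> facebook/opt-125m
```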
@@ -469,6 +469,7 @@ class _BaseAutoModelClass:
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
         code_revision = kwargs.pop("code_revision", None)
         commit_hash = kwargs.pop("_commit_hash", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", None)
 
         revision = hub_kwargs.pop("revision", None)
         hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
@@ -503,15 +504,18 @@ class _BaseAutoModelClass:
             commit_hash = getattr(config, "_commit_hash", None)
 
         if is_peft_available():
+            if adapter_kwargs is None:
+                adapter_kwargs = {}
+
             maybe_adapter_path = find_adapter_config_file(
-                pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
+                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
             )
 
             if maybe_adapter_path is not None:
                 with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                     adapter_config = json.load(f)
 
-                    kwargs["_adapter_model_path"] = pretrained_model_name_or_path
+                    adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                     pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
 
         if not isinstance(config, PretrainedConfig):
@@ -545,6 +549,10 @@ class _BaseAutoModelClass:
         trust_remote_code = resolve_trust_remote_code(
             trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
         )
 
+        # Set the adapter kwargs
+        kwargs["adapter_kwargs"] = adapter_kwargs
+
         if has_remote_code and trust_remote_code:
             class_ref = config.auto_map[cls.__name__]
             model_class = get_class_from_dynamic_module(
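End to end, the auto class now threads the adapter location through `adapter_kwargs` rather than a top-level private kwarg. A hedged sketch of the flow; the repo ids are illustrative placeholders:

```python
# Simplified flow through _BaseAutoModelClass.from_pretrained.
kwargs = {}
adapter_kwargs = {"revision": "test"}  # as supplied by the caller
# ... find_adapter_config_file located an adapter_config.json ...
adapter_kwargs["_adapter_model_path"] = "some-user/some-lora-adapter"
pretrained_model_name_or_path = "facebook/opt-125m"  # adapter_config["base_model_name_or_path"]
kwargs["adapter_kwargs"] = adapter_kwargs
# The concrete model class later pops "_adapter_model_path" back out of
# adapter_kwargs and ends with:
#   model.load_adapter(_adapter_model_path, adapter_name=adapter_name,
#                      token=token, adapter_kwargs=adapter_kwargs)
```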
@@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
             # dummy generation
             _ = model.generate(input_ids=dummy_input)
+
+    def test_peft_from_pretrained_hub_kwargs(self):
+        """
+        Tests different combinations of PEFT model + from_pretrained + hub kwargs
+        """
+        peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"
+
+        # This should not work
+        with self.assertRaises(OSError):
+            _ = AutoModelForCausalLM.from_pretrained(peft_model_id)
+
+        adapter_kwargs = {"revision": "test"}
+
+        # This should work
+        model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        adapter_kwargs = {"revision": "main", "subfolder": "test_subfolder"}
+
+        model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
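Outside the test suite, the same call shape covers adapters stored in a repo subfolder; a sketch mirroring the last case above, reusing the test repo for illustration:

```python
from transformers import AutoModelForCausalLM

# Adapter files sit under test_subfolder/ on the main branch of the adapter
# repo; the base model is still taken from base_model_name_or_path.
model = AutoModelForCausalLM.from_pretrained(
    "peft-internal-testing/tiny-opt-lora-revision",
    adapter_kwargs={"revision": "main", "subfolder": "test_subfolder"},
)
```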