"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5ca356a464e98e065488205f3fcf9247f56c3832"
Unverified commit 38e96324, authored by Younes Belkada, committed by GitHub

[`PEFT`] introducing `adapter_kwargs` for loading adapters from different Hub location (`subfolder`, `revision`) than the base model (#26270)

* make use of adapter_revision

* v1 adapter kwargs

* fix CI

* fix CI

* fix CI

* fixup

* add BC

* Update src/transformers/integrations/peft.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* change it to error

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* fixup

* change

* Update src/transformers/integrations/peft.py

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 52e2c13d
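
At a high level, `adapter_kwargs` carries Hub-resolution options (`revision`, `subfolder`, ...) that apply only to the adapter, not to the base model. A minimal usage sketch, mirroring the test added at the end of this diff (it assumes `peft` is installed and the `peft-internal-testing/tiny-opt-lora-revision` test repository is reachable):

from transformers import AutoModelForCausalLM

# Without `adapter_kwargs`, loading this repo fails with an OSError because the
# adapter weights live on a non-default revision of the Hub repo.
model = AutoModelForCausalLM.from_pretrained(
    "peft-internal-testing/tiny-opt-lora-revision",
    adapter_kwargs={"revision": "test"},
)

# `revision`/`subfolder` inside `adapter_kwargs` only affect how the adapter is
# fetched; the base model referenced by the adapter config is resolved as usual.
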
@@ -77,6 +77,7 @@ class PeftAdapterMixin:
         offload_index: Optional[int] = None,
         peft_config: Dict[str, Any] = None,
         adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        adapter_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
@@ -128,10 +129,15 @@ class PeftAdapterMixin:
             adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                 dicts
+            adapter_kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
+                `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)

         adapter_name = adapter_name if adapter_name is not None else "default"
+        if adapter_kwargs is None:
+            adapter_kwargs = {}

         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict
@@ -144,11 +150,20 @@ class PeftAdapterMixin:
                 "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
             )

+        # We keep `revision` in the signature for backward compatibility
+        if revision is not None and "revision" not in adapter_kwargs:
+            adapter_kwargs["revision"] = revision
+        elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]:
+            logger.error(
+                "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. "
+                "The one in `adapter_kwargs` will be used."
+            )
+
         if peft_config is None:
             adapter_config_file = find_adapter_config_file(
                 peft_model_id,
-                revision=revision,
                 token=token,
+                **adapter_kwargs,
             )

             if adapter_config_file is None:
@@ -159,8 +174,8 @@ class PeftAdapterMixin:
             peft_config = PeftConfig.from_pretrained(
                 peft_model_id,
-                revision=revision,
                 use_auth_token=token,
+                **adapter_kwargs,
             )

             # Create and add fresh new adapters into the model.
@@ -170,7 +185,7 @@ class PeftAdapterMixin:
             self._hf_peft_config_loaded = True

         if peft_model_id is not None:
-            adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
+            adapter_state_dict = load_peft_weights(peft_model_id, use_auth_token=token, **adapter_kwargs)

         # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
         processed_adapter_state_dict = {}
...
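
The same dictionary can also be passed to `load_adapter` directly, since it is forwarded to `find_adapter_config_file`, `PeftConfig.from_pretrained`, and `load_peft_weights`. A rough sketch under assumed inputs (the adapter repo id and `subfolder` below are placeholders, not real repositories):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

# Placeholder adapter location, for illustration only.
model.load_adapter(
    "some-user/opt-350m-lora",  # hypothetical adapter repo
    adapter_name="default",
    adapter_kwargs={"revision": "main", "subfolder": "checkpoint-500"},
)

# Passing `revision=` as a standalone argument still works for backward
# compatibility; if both are given, the value inside `adapter_kwargs` wins
# and an error is logged.
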
@@ -623,6 +623,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)

+        # Not relevant for Flax Models
+        _ = kwargs.pop("adapter_kwargs", None)
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
...
@@ -2645,6 +2645,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         commit_hash = kwargs.pop("_commit_hash", None)
         tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)

+        # Not relevant for TF models
+        _ = kwargs.pop("adapter_kwargs", None)
+
         if use_auth_token is not None:
             warnings.warn(
                 "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
...
@@ -2463,7 +2463,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         subfolder = kwargs.pop("subfolder", "")
         commit_hash = kwargs.pop("_commit_hash", None)
         variant = kwargs.pop("variant", None)
-        _adapter_model_path = kwargs.pop("_adapter_model_path", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
@@ -2516,6 +2516,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             commit_hash = getattr(config, "_commit_hash", None)

         if is_peft_available():
+            _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
+
             if _adapter_model_path is None:
                 _adapter_model_path = find_adapter_config_file(
                     pretrained_model_name_or_path,
@@ -2525,14 +2527,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
-                    revision=revision,
-                    subfolder=subfolder,
                     _commit_hash=commit_hash,
+                    **adapter_kwargs,
                 )
             if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
                 with open(_adapter_model_path, "r", encoding="utf-8") as f:
                     _adapter_model_path = pretrained_model_name_or_path
                     pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+        else:
+            _adapter_model_path = None

         # change device_map into a map if we passed an int, a str or a torch.device
         if isinstance(device_map, torch.device):
@@ -3371,8 +3374,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model.load_adapter(
                 _adapter_model_path,
                 adapter_name=adapter_name,
-                revision=revision,
                 token=token,
+                adapter_kwargs=adapter_kwargs,
             )

         if output_loading_info:
...
@@ -469,6 +469,7 @@ class _BaseAutoModelClass:
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
         code_revision = kwargs.pop("code_revision", None)
         commit_hash = kwargs.pop("_commit_hash", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", None)
         revision = hub_kwargs.pop("revision", None)
         hub_kwargs["revision"] = sanitize_code_revision(pretrained_model_name_or_path, revision, trust_remote_code)
@@ -503,15 +504,18 @@ class _BaseAutoModelClass:
             commit_hash = getattr(config, "_commit_hash", None)

         if is_peft_available():
+            if adapter_kwargs is None:
+                adapter_kwargs = {}
+
             maybe_adapter_path = find_adapter_config_file(
-                pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
+                pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs
             )

             if maybe_adapter_path is not None:
                 with open(maybe_adapter_path, "r", encoding="utf-8") as f:
                     adapter_config = json.load(f)

-                    kwargs["_adapter_model_path"] = pretrained_model_name_or_path
+                    adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path
                     pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

         if not isinstance(config, PretrainedConfig):
@@ -545,6 +549,10 @@ class _BaseAutoModelClass:
         trust_remote_code = resolve_trust_remote_code(
             trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
         )

+        # Set the adapter kwargs
+        kwargs["adapter_kwargs"] = adapter_kwargs
+
         if has_remote_code and trust_remote_code:
             class_ref = config.auto_map[cls.__name__]
             model_class = get_class_from_dynamic_module(
...
@@ -351,3 +351,30 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
         # dummy generation
         _ = model.generate(input_ids=dummy_input)
+
+    def test_peft_from_pretrained_hub_kwargs(self):
+        """
+        Tests different combinations of PEFT model + from_pretrained + hub kwargs
+        """
+        peft_model_id = "peft-internal-testing/tiny-opt-lora-revision"
+
+        # This should not work
+        with self.assertRaises(OSError):
+            _ = AutoModelForCausalLM.from_pretrained(peft_model_id)
+
+        adapter_kwargs = {"revision": "test"}
+
+        # This should work
+        model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        adapter_kwargs = {"revision": "main", "subfolder": "test_subfolder"}
+
+        model = AutoModelForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))
+
+        model = OPTForCausalLM.from_pretrained(peft_model_id, adapter_kwargs=adapter_kwargs)
+        self.assertTrue(self._check_lora_correctly_converted(model))