"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6eb51450fa2a440a45e02b29f01e4f2aa4f70a4d"
Unverified Commit 2d1602ae authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

PEFT / Trainer: Make use of `model.active_adapters()` instead of deprecated...

PEFT / Trainer: Make use of `model.active_adapters()` instead of deprecated `model.active_adapter` whenever possible (#30738)

* Update trainer.py

* Update src/transformers/trainer.py

* Update src/transformers/trainer.py

* Update src/transformers/trainer.py

* style

* Update src/transformers/trainer.py

* Update src/transformers/trainer.py
parent 1c52cb7b
...@@ -2526,16 +2526,27 @@ class Trainer: ...@@ -2526,16 +2526,27 @@ class Trainer:
# Load adapters following PR # 24096 # Load adapters following PR # 24096
elif _is_peft_model(model): elif _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly. # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): # TODO: in the future support only specific min PEFT versions
if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
model, "load_adapter"
):
if os.path.exists(resume_from_checkpoint): if os.path.exists(resume_from_checkpoint):
if adapter_subdirs: # For BC for older PEFT versions
if hasattr(model, "active_adapters"):
active_adapters = model.active_adapters()
if len(active_adapters) > 1:
logger.warning("Multiple active adapters detected will only consider the first adapter")
active_adapter = active_adapters[0]
else:
active_adapter = model.active_adapter active_adapter = model.active_adapter
if adapter_subdirs:
for subdir_name in adapter_subdirs: for subdir_name in adapter_subdirs:
peft_id = os.path.join(resume_from_checkpoint, subdir_name) peft_id = os.path.join(resume_from_checkpoint, subdir_name)
model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter)) model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
model.set_adapter(active_adapter) model.set_adapter(active_adapter)
else: else:
model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
else: else:
logger.warning( logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, " "The intermediate checkpoints of PEFT may not be saved correctly, "
...@@ -2609,9 +2620,20 @@ class Trainer: ...@@ -2609,9 +2620,20 @@ class Trainer:
else: else:
if _is_peft_model(model): if _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly. # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): # TODO: in the future support only specific min PEFT versions
if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
model, "load_adapter"
):
# For BC for older PEFT versions
if hasattr(model, "active_adapters"):
active_adapter = model.active_adapters()[0]
if len(model.active_adapters()) > 1:
logger.warning("Detected multiple active adapters, will only consider the first one")
else:
active_adapter = model.active_adapter
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) model.load_adapter(self.state.best_model_checkpoint, active_adapter)
# Load_adapter has no return value present, modify it when appropriate. # Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys from torch.nn.modules.module import _IncompatibleKeys
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment