"docs/source/de/index.mdx" did not exist on "f9a0008d2d3082a665f711b24f5314e4a8205fab"
Unverified commit e6cb8e05 authored by Wang, Yi, committed by GitHub

in peft finetune, only the trainable parameters need to be saved (#27825)



This reduces checkpoint storage size and also shortens checkpoint-saving time when training with DeepSpeed.
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 7f2a8f92
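
For context (not part of the commit itself): in a PEFT fine-tune such as LoRA, the base model is frozen and only the adapter weights require gradients, which is why excluding frozen parameters shrinks the checkpoint so much. A minimal sketch, assuming `transformers` and `peft` are installed and using an illustrative model name:

```python
# Sketch only: measure how small the trainable (adapter) fraction is.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))

trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_model.parameters())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```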
@@ -212,6 +212,10 @@ if is_accelerate_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper


+def _is_peft_model(model):
+    return is_peft_available() and isinstance(model, PeftModel)
+
+
 if TYPE_CHECKING:
     import optuna
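
A note on the new `_is_peft_model` helper: `PeftModel` is only importable when `peft` is installed, so the `is_peft_available()` short-circuit keeps the `isinstance` check from ever touching an undefined name. A standalone sketch of the same guard pattern:

```python
# Sketch only: the guard short-circuits before PeftModel is needed.
from transformers.utils import is_peft_available

if is_peft_available():
    from peft import PeftModel


def _is_peft_model(model):
    # Returns False when peft is missing; isinstance is never reached then.
    return is_peft_available() and isinstance(model, PeftModel)
```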
@@ -398,13 +402,12 @@ class Trainer:
                 " to `True` to avoid any unexpected behavior such as device placement mismatching."
             )

-        _is_peft_model = is_peft_available() and isinstance(model, PeftModel)
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )

         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model:
+        if _is_quantized_and_base_model and not _is_peft_model(model):
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
@@ -619,7 +622,7 @@ class Trainer:
         """
         unwrapped_model = unwrap_model(model)

-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
@@ -640,7 +643,7 @@ class Trainer:
         unwrapped_model = unwrap_model(model)

-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
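
In both embedding hunks the PEFT branch reaches through the wrapper: a `PeftModel` holds a tuner (e.g. a `LoraModel`) as `base_model`, whose `.model` attribute is the original transformer. A small, self-contained illustration of that hierarchy, with an illustrative model:

```python
# Sketch only: the wrapping hierarchy the branches above navigate.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

peft_model = get_peft_model(
    AutoModelForCausalLM.from_pretrained("gpt2"), LoraConfig(r=8)
)
inner = peft_model.base_model.model          # the original transformer
embeddings = inner.get_input_embeddings()    # its nn.Embedding module
print(type(peft_model.base_model).__name__)  # e.g. "LoraModel"
```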
@@ -696,7 +699,7 @@ class Trainer:
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
             model_to_inspect = self.model
-            if is_peft_available() and isinstance(self.model, PeftModel):
+            if _is_peft_model(self.model):
                 model_to_inspect = self.model.get_base_model()
             signature = inspect.signature(model_to_inspect.forward)
             self._signature_columns = list(signature.parameters.keys())
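
For reference, the signature inspection simply lists the keyword arguments `forward()` accepts, which become the dataset columns `Trainer` keeps. A sketch with an illustrative model:

```python
# Sketch only: the columns Trainer keeps are forward()'s parameter names.
import inspect

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
print(list(inspect.signature(model.forward).parameters))
# ['input_ids', 'past_key_values', 'attention_mask', ...]
```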
@@ -2114,7 +2117,7 @@ class Trainer:
             self._issue_warnings_after_load(load_result)
         # Load adapters following PR # 24096
-        elif is_peft_available() and isinstance(model, PeftModel):
+        elif _is_peft_model(model):
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                 if os.path.exists(resume_from_checkpoint):
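
A hedged sketch of the resume path this branch covers, assuming `model` is a LoRA-wrapped `PeftModel` as in the earlier sketches and that the checkpoint path is illustrative (`is_trainable` is a real `PeftModel.load_adapter` parameter):

```python
# Sketch only: reload adapter weights from a checkpoint when resuming.
import os

checkpoint_dir = "output/checkpoint-500"  # illustrative path
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
    if os.path.exists(checkpoint_dir):
        model.load_adapter(checkpoint_dir, model.active_adapter, is_trainable=True)
```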
@@ -2177,7 +2180,7 @@ class Trainer:
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
             else:
-                if is_peft_available() and isinstance(model, PeftModel):
+                if _is_peft_model(model):
                     # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                     if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                         if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
@@ -2453,6 +2456,12 @@ class Trainer:
         elif self.is_deepspeed_enabled:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+            )
+            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+            else:
+                self.model_wrapped.save_checkpoint(output_dir)
         elif self.is_fsdp_enabled:
             # save fsdp specific ckpt for resuming from ckpt
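
This is the core of the change: `exclude_frozen_parameters` is a newer addition to DeepSpeed's `save_checkpoint`, so the code feature-detects the keyword via `inspect.signature` instead of pinning a version. A standalone sketch of the same pattern, with a hypothetical `engine` standing in for `self.model_wrapped`:

```python
# Sketch only: pass exclude_frozen_parameters only when the installed
# DeepSpeed's save_checkpoint accepts it; otherwise fall back gracefully.
import inspect


def save_checkpoint_compat(engine, output_dir, is_peft_model):
    accepts = "exclude_frozen_parameters" in inspect.signature(
        engine.save_checkpoint
    ).parameters
    if accepts and is_peft_model:
        # Frozen base-model weights are skipped; only adapters hit disk.
        engine.save_checkpoint(output_dir, exclude_frozen_parameters=True)
    else:
        engine.save_checkpoint(output_dir)
```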
@@ -2766,7 +2775,7 @@ class Trainer:
         if labels is not None:
             unwrapped_model = unwrap_model(model)
-            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+            if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
                 model_name = unwrapped_model._get_name()