Unverified commit e6cb8e05, authored by Wang, Yi; committed by GitHub
Browse files

In PEFT fine-tuning, only the trainable parameters need to be saved (#27825)



Saving only the trainable parameters reduces checkpoint storage size and shortens checkpoint-saving time when training with DeepSpeed.
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 7f2a8f92
......@@ -212,6 +212,10 @@ if is_accelerate_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
def _is_peft_model(model):
    """Report whether ``model`` is a ``PeftModel`` instance.

    The availability check runs first, so the ``isinstance`` test is only
    evaluated when ``is_peft_available()`` returns a truthy value.

    Args:
        model: The object to test.

    Returns:
        The falsy result of ``is_peft_available()`` when it is falsy,
        otherwise the boolean result of ``isinstance(model, PeftModel)``.
    """
    peft_available = is_peft_available()
    # Short-circuit: `PeftModel` is not referenced unless the check above passed.
    return peft_available and isinstance(model, PeftModel)
if TYPE_CHECKING:
import optuna
......@@ -398,13 +402,12 @@ class Trainer:
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)
_is_peft_model = is_peft_available() and isinstance(model, PeftModel)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False
)
# At this stage the model is already loaded
if _is_quantized_and_base_model and not _is_peft_model:
if _is_quantized_and_base_model and not _is_peft_model(model):
raise ValueError(
"You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
......@@ -619,7 +622,7 @@ class Trainer:
"""
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
......@@ -640,7 +643,7 @@ class Trainer:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else:
embeddings = unwrapped_model.get_input_embeddings()
......@@ -696,7 +699,7 @@ class Trainer:
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
model_to_inspect = self.model
if is_peft_available() and isinstance(self.model, PeftModel):
if _is_peft_model(self.model):
model_to_inspect = self.model.get_base_model()
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
......@@ -2114,7 +2117,7 @@ class Trainer:
self._issue_warnings_after_load(load_result)
# Load adapters following PR # 24096
elif is_peft_available() and isinstance(model, PeftModel):
elif _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(resume_from_checkpoint):
......@@ -2177,7 +2180,7 @@ class Trainer:
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
else:
if is_peft_available() and isinstance(model, PeftModel):
if _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
......@@ -2453,7 +2456,13 @@ class Trainer:
elif self.is_deepspeed_enabled:
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
# config `stage3_gather_16bit_weights_on_model_save` is True
self.model_wrapped.save_checkpoint(output_dir)
accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
)
if accept_exclude_frozen_parameters and _is_peft_model(self.model):
self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
else:
self.model_wrapped.save_checkpoint(output_dir)
elif self.is_fsdp_enabled:
# save fsdp specific ckpt for resuming from ckpt
save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
......@@ -2766,7 +2775,7 @@ class Trainer:
if labels is not None:
unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment