Unverified commit e6cb8e05, authored by Wang, Yi, committed by GitHub

In PEFT fine-tuning, only the trainable parameters need to be saved (#27825)

This reduces checkpoint storage size and shortens the time spent saving checkpoints when training with DeepSpeed.

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 7f2a8f92
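For context on the motivation: in a LoRA/PEFT setup almost all parameters are frozen, so a checkpoint that skips the frozen base weights is dramatically smaller and faster to write. A rough sketch of that ratio follows; the model name and LoRA settings are illustrative assumptions, not taken from this PR.

```python
# Rough sketch, assuming `transformers` and `peft` are installed; the model
# name and LoRA settings below are illustrative assumptions.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable:,} of {total:,} parameters ({100 * trainable / total:.2f}%)")
# Only this small adapter slice needs to land in the DeepSpeed checkpoint.
```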
@@ -212,6 +212,10 @@ if is_accelerate_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
 
 
+def _is_peft_model(model):
+    return is_peft_available() and isinstance(model, PeftModel)
+
+
 if TYPE_CHECKING:
     import optuna
...@@ -398,13 +402,12 @@ class Trainer: ...@@ -398,13 +402,12 @@ class Trainer:
" to `True` to avoid any unexpected behavior such as device placement mismatching." " to `True` to avoid any unexpected behavior such as device placement mismatching."
) )
_is_peft_model = is_peft_available() and isinstance(model, PeftModel)
_is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr( _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
model, "_hf_peft_config_loaded", False model, "_hf_peft_config_loaded", False
) )
# At this stage the model is already loaded # At this stage the model is already loaded
if _is_quantized_and_base_model and not _is_peft_model: if _is_quantized_and_base_model and not _is_peft_model(model):
raise ValueError( raise ValueError(
"You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of" "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
" the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft" " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
@@ -619,7 +622,7 @@ class Trainer:
         """
         unwrapped_model = unwrap_model(model)
-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
@@ -640,7 +643,7 @@ class Trainer:
         unwrapped_model = unwrap_model(model)
-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
@@ -696,7 +699,7 @@ class Trainer:
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
             model_to_inspect = self.model
-            if is_peft_available() and isinstance(self.model, PeftModel):
+            if _is_peft_model(self.model):
                 model_to_inspect = self.model.get_base_model()
             signature = inspect.signature(model_to_inspect.forward)
             self._signature_columns = list(signature.parameters.keys())
@@ -2114,7 +2117,7 @@ class Trainer:
                 self._issue_warnings_after_load(load_result)
 
         # Load adapters following PR # 24096
-        elif is_peft_available() and isinstance(model, PeftModel):
+        elif _is_peft_model(model):
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                 if os.path.exists(resume_from_checkpoint):
...@@ -2177,7 +2180,7 @@ class Trainer: ...@@ -2177,7 +2180,7 @@ class Trainer:
state_dict["_smp_is_partial"] = False state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True) load_result = model.load_state_dict(state_dict, strict=True)
else: else:
if is_peft_available() and isinstance(model, PeftModel): if _is_peft_model(model):
# If train a model using PEFT & LoRA, assume that adapter have been saved properly. # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
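The two hunks above only swap in the `_is_peft_model` helper; the adapter resume and best-model logic itself is unchanged. For reference, a hedged sketch of what loading a saved adapter back onto a base model looks like with PEFT; the checkpoint path and adapter name are hypothetical, not from this PR.

```python
# Hypothetical sketch; the checkpoint path and adapter name are assumptions.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("gpt2")
# Rebuild the PeftModel from adapter weights saved in a checkpoint folder.
model = PeftModel.from_pretrained(base, "checkpoint-500")
# An existing PeftModel can also pull in an additional adapter under a new name.
model.load_adapter("checkpoint-500", adapter_name="best")
```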
@@ -2453,7 +2456,13 @@ class Trainer:
         elif self.is_deepspeed_enabled:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+            )
+            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+            else:
+                self.model_wrapped.save_checkpoint(output_dir)
         elif self.is_fsdp_enabled:
             # save fsdp specific ckpt for resuming from ckpt
             save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
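The DeepSpeed hunk above probes the `save_checkpoint` signature so the new `exclude_frozen_parameters=True` argument is only passed on DeepSpeed versions that accept it, while older versions keep the original full save. The same feature-detection pattern in isolation; the `save_checkpoint` below is a stand-in function, not DeepSpeed's engine method.

```python
# Standalone sketch of the signature probe; `save_checkpoint` is a stand-in.
import inspect

def save_checkpoint(save_dir, exclude_frozen_parameters=False):
    print(f"saving to {save_dir}, exclude_frozen_parameters={exclude_frozen_parameters}")

if "exclude_frozen_parameters" in inspect.signature(save_checkpoint).parameters:
    # Newer API: skip frozen (non-trainable) weights to shrink the checkpoint.
    save_checkpoint("output_dir", exclude_frozen_parameters=True)
else:
    # Older API: fall back to the original full save.
    save_checkpoint("output_dir")
```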
@@ -2766,7 +2775,7 @@ class Trainer:
         if labels is not None:
             unwrapped_model = unwrap_model(model)
-            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+            if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
                 model_name = unwrapped_model._get_name()