"docs/source/vscode:/vscode.git/clone" did not exist on "4c2b4c4c3c79b8d1250efe95daf3682fbcdbaa39"
Unverified commit e6cb8e05, authored by Wang, Yi and committed by GitHub

in peft finetune, only the trainable parameters need to be saved (#27825)



to reduce the storage size and save checkpoint-writing time when using DeepSpeed for training
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
parent 7f2a8f92
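
For context on the motivation (not part of the commit itself): in a PEFT/LoRA run only a small fraction of the model's parameters is trainable, so a DeepSpeed checkpoint that skips the frozen base weights is far smaller and faster to write. Below is a minimal sketch of that ratio, assuming transformers and peft are installed; the tiny model and the LoRA settings are illustrative choices, not taken from this change.

    # Illustrative only: show how small the trainable subset of a LoRA model is.
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # tiny GPT-2, for illustration
    model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} of {total:,} parameters ({100 * trainable / total:.2f}%)")
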
@@ -212,6 +212,10 @@ if is_accelerate_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
 
 
+def _is_peft_model(model):
+    return is_peft_available() and isinstance(model, PeftModel)
+
+
 if TYPE_CHECKING:
     import optuna
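
The `_is_peft_model` helper added above centralizes a guard that was previously repeated inline throughout trainer.py. A standalone rendering of the same idea follows, in case it is useful outside the Trainer; the deferred import is only there to keep the snippet self-contained (trainer.py imports PeftModel at module level when peft is available).

    # Treat an object as a PEFT model only when the peft package is importable,
    # so the isinstance check cannot raise ImportError on installs without peft.
    from transformers.utils import is_peft_available

    def is_peft_model(model):
        if is_peft_available():
            from peft import PeftModel
            return isinstance(model, PeftModel)
        return False
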
@@ -398,13 +402,12 @@ class Trainer:
                     " to `True` to avoid any unexpected behavior such as device placement mismatching."
                 )
 
-        _is_peft_model = is_peft_available() and isinstance(model, PeftModel)
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )
 
         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model:
+        if _is_quantized_and_base_model and not _is_peft_model(model):
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
@@ -619,7 +622,7 @@ class Trainer:
         """
         unwrapped_model = unwrap_model(model)
-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
@@ -640,7 +643,7 @@ class Trainer:
         unwrapped_model = unwrap_model(model)
-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+        if _is_peft_model(unwrapped_model):
             embeddings = unwrapped_model.base_model.model.get_input_embeddings()
         else:
             embeddings = unwrapped_model.get_input_embeddings()
@@ -696,7 +699,7 @@ class Trainer:
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.
             model_to_inspect = self.model
-            if is_peft_available() and isinstance(self.model, PeftModel):
+            if _is_peft_model(self.model):
                 model_to_inspect = self.model.get_base_model()
             signature = inspect.signature(model_to_inspect.forward)
             self._signature_columns = list(signature.parameters.keys())
@@ -2114,7 +2117,7 @@ class Trainer:
                 self._issue_warnings_after_load(load_result)
         # Load adapters following PR # 24096
-        elif is_peft_available() and isinstance(model, PeftModel):
+        elif _is_peft_model(model):
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                 if os.path.exists(resume_from_checkpoint):
@@ -2177,7 +2180,7 @@ class Trainer:
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
         else:
-            if is_peft_available() and isinstance(model, PeftModel):
+            if _is_peft_model(model):
                 # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
                 if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                     if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
@@ -2453,7 +2456,13 @@ class Trainer:
         elif self.is_deepspeed_enabled:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+            )
+            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+            else:
+                self.model_wrapped.save_checkpoint(output_dir)
         elif self.is_fsdp_enabled:
             # save fsdp specific ckpt for resuming from ckpt
             save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
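
The `inspect.signature` check above keeps the change backwards compatible: older DeepSpeed releases do not accept `exclude_frozen_parameters` in `save_checkpoint`, so the flag is only passed when the installed engine supports it, and PEFT checkpoints then contain just the trainable (adapter) weights. A hedged sketch of the same capability probe, run against the engine class directly and assuming DeepSpeed is installed:

    # Probe the installed DeepSpeed version for checkpoint support of
    # exclude_frozen_parameters without building an engine first.
    import inspect

    import deepspeed

    supports_exclude_frozen = (
        "exclude_frozen_parameters"
        in inspect.signature(deepspeed.DeepSpeedEngine.save_checkpoint).parameters
    )
    print(f"save_checkpoint supports exclude_frozen_parameters: {supports_exclude_frozen}")
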
@@ -2766,7 +2775,7 @@ class Trainer:
         if labels is not None:
             unwrapped_model = unwrap_model(model)
-            if is_peft_available() and isinstance(unwrapped_model, PeftModel):
+            if _is_peft_model(unwrapped_model):
                 model_name = unwrapped_model.base_model.model._get_name()
             else:
                 model_name = unwrapped_model._get_name()
...