"...composable_kernel.git" did not exist on "c0bfcf9101c542d68e3ee2fb3b1308fd5b73f31e"
Unverified commit bebeeee0, authored by Hieu Lam, committed by GitHub

Resolve DeepSpeed cannot resume training with PeftModel (#28746)

* fix: resolve deepspeed resume peft model issues

* chore: update something

* chore: pass the model instance into the is-PEFT-model checks

* chore: remove hard-coded values from tests

* fix: format code
parent 65a926e8
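
What the change addresses, in short: when the model handed to the Trainer is a PeftModel, the checkpoint written during training typically holds adapter-only module state, so DeepSpeed's default strict module load (load_module_strict=True) fails on the missing base-model keys when resuming. The diff below threads a load_module_strict flag through deepspeed_load_checkpoint and has the Trainer relax it for PEFT models. A minimal sketch of that decision, assuming the check is essentially an isinstance test against peft's PeftModel (the helper name mirrors the _is_peft_model used in the diff):

# Minimal sketch, not the exact Trainer code.
from peft import PeftModel

def _is_peft_model(model) -> bool:
    # Adapter-wrapped models save adapter-only module state, so a strict
    # DeepSpeed module load would complain about missing base-model keys.
    return isinstance(model, PeftModel)

def module_load_strictness(model) -> bool:
    # Strict for plain models, relaxed for PEFT-wrapped ones.
    return not _is_peft_model(model)
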
@@ -143,14 +143,25 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
             "per_device_train_batch_size",
             not auto_find_batch_size,
         )
-        self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
-        self.fill_match(
-            "train_batch_size", train_batch_size, "train_batch_size (calculated)", not auto_find_batch_size
-        )
+        self.fill_match(
+            "gradient_accumulation_steps",
+            args.gradient_accumulation_steps,
+            "gradient_accumulation_steps",
+        )
+        self.fill_match(
+            "train_batch_size",
+            train_batch_size,
+            "train_batch_size (calculated)",
+            not auto_find_batch_size,
+        )
         self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
         self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
-        self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
+        self.fill_match(
+            "optimizer.params.betas",
+            [args.adam_beta1, args.adam_beta2],
+            "adam_beta1+adam_beta2",
+        )
         self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
         self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
@@ -225,12 +236,26 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
             self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
             if self.is_zero3():
                 # automatically assign the optimal config values based on model config
-                self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
-                self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
+                self.fill_only(
+                    "zero_optimization.stage3_prefetch_bucket_size",
+                    0.9 * hidden_size * hidden_size,
+                )
+                self.fill_only(
+                    "zero_optimization.stage3_param_persistence_threshold",
+                    10 * hidden_size,
+                )

         # scheduler
-        self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
-        self.fill_match("scheduler.params.warmup_num_steps", args.get_warmup_steps(num_training_steps), "warmup_steps")
+        self.fill_match(
+            "scheduler.params.total_num_steps",
+            num_training_steps,
+            "num_training_steps (calculated)",
+        )
+        self.fill_match(
+            "scheduler.params.warmup_num_steps",
+            args.get_warmup_steps(num_training_steps),
+            "warmup_steps",
+        )

         if len(self.mismatches) > 0:
             mismatches = "\n".join(self.mismatches)
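For context on the calls reflowed above: fill_match resolves "auto" placeholders in the user's DeepSpeed config from TrainingArguments (recording a mismatch if the config holds a conflicting concrete value), while fill_only derives the ZeRO-3 bucket sizes from the model's hidden_size without mismatch tracking. A hypothetical config fragment showing which keys these calls fill:

# Hypothetical ds_config fragment: every "auto" here is a value that
# HfTrainerDeepSpeedConfig resolves via the fill_match()/fill_only() calls above.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {"total_num_steps": "auto", "warmup_num_steps": "auto"},
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
    },
}
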
@@ -387,7 +412,7 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler


-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
+def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -400,7 +425,10 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
         logger.info(f"Attempting to resume from {checkpoint_path}")
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path,
+            load_module_strict=load_module_strict,
+            load_optimizer_states=True,
+            load_lr_scheduler_states=True,
         )
         if load_path is None:
             raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
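The new keyword is forwarded straight to DeepSpeed's engine.load_checkpoint, which already accepts load_module_strict. A sketch of calling the updated helper directly; the checkpoint directory is a placeholder, and the import path assumes this diff's file lives at transformers.integrations.deepspeed:

# Illustrative call pattern only.
from transformers.integrations.deepspeed import deepspeed_load_checkpoint

def resume_from_dir(deepspeed_engine, checkpoint_dir, peft_wrapped: bool):
    # Strict module loading for full models; relaxed when the saved module
    # state is adapter-only, which is the distinction the new keyword exposes.
    deepspeed_load_checkpoint(
        deepspeed_engine,
        checkpoint_dir,
        load_module_strict=not peft_wrapped,
    )
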
...
@@ -1727,7 +1727,9 @@ class Trainer:
         # ckpt loading
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
-                deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+                deepspeed_load_checkpoint(
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+                )
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
@@ -2193,7 +2195,11 @@ class Trainer:
         model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

         if self.is_deepspeed_enabled:
-            deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+            deepspeed_load_checkpoint(
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
+            )
         elif self.is_fsdp_enabled:
             load_result = load_fsdp_model(
                 self.accelerator.state.fsdp_plugin,
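User-facing effect of the two Trainer call sites above: resuming a LoRA fine-tune that runs under DeepSpeed no longer trips strict module loading. A rough end-to-end sketch of that scenario; the model name, DeepSpeed config path, and dataset are placeholders, not taken from the commit:

# Requires deepspeed and peft installed and a distributed launcher
# (e.g. `deepspeed` or `torchrun`); "gpt2", "ds_config.json" and train_ds
# are placeholders.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))  # adapter-only trainable state

args = TrainingArguments(
    output_dir="outputs",
    deepspeed="ds_config.json",          # hypothetical DeepSpeed config (e.g. ZeRO-2/3)
    per_device_train_batch_size=1,
    num_train_epochs=1,
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds)  # train_ds: placeholder

# Previously this resume raised inside deepspeed_load_checkpoint because the
# adapter-only module state did not strictly match the full engine module;
# with load_module_strict=not _is_peft_model(model) the checkpoint loads and
# training continues.
trainer.train(resume_from_checkpoint=True)
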
...