Unverified Commit a6609caf authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

More frozen args (#25540)

parent f61f072b
......@@ -1165,6 +1165,8 @@ class Trainer:
elif self.hp_search_backend == HPSearchBackend.WANDB:
params = trial
# Unfreeze args for hyperparameter search
delattr(self.args, "_frozen")
for key, value in params.items():
if not hasattr(self.args, key):
logger.warning(
......@@ -1176,6 +1178,7 @@ class Trainer:
# Casting value to the proper type
if old_attr is not None:
value = type(old_attr)(value)
setattr(self.args, key, value)
if self.hp_search_backend == HPSearchBackend.OPTUNA:
logger.info(f"Trial: {trial.params}")
......@@ -1194,6 +1197,9 @@ class Trainer:
self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
self.args.hf_deepspeed_config.trainer_config_process(self.args)
self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
# Re-freeze them
setattr(self.args, "_frozen", True)
self.create_accelerator_and_postprocess()
def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment