Unverified Commit 61abd50b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Remove use of deprected method in Trainer HP search (#8996)

parent 7e1d709e
...@@ -543,7 +543,7 @@ class Trainer: ...@@ -543,7 +543,7 @@ class Trainer:
self.args.output_dir = checkpoint_dir self.args.output_dir = checkpoint_dir
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir) self.save_model(output_dir)
if self.is_world_master(): if self.is_world_process_zero():
self.state.save_to_json(os.path.join(output_dir, "trainer_state.json")) self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment