Commit 523adaf4 authored by Jennifer's avatar Jennifer
Browse files

adds reload_dataloaders_every_n_epochs flag

parent 577219c1
......@@ -416,8 +416,7 @@ def main(args):
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
# Raw dump of all args from pl.Trainer constructor
trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps']
trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
trainer_args.update({
'default_root_dir': args.output_dir,
......@@ -658,6 +657,9 @@ if __name__ == "__main__":
trainer_group.add_argument(
"--num_sanity_val_steps", type=int, default=0,
)
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
args = parser.parse_args()
......@@ -673,7 +675,4 @@ if __name__ == "__main__":
raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment