Commit 523adaf4 authored by Jennifer's avatar Jennifer
Browse files

adds reload_dataloaders_every_n_epochs flag

parent 577219c1
...@@ -416,8 +416,7 @@ def main(args): ...@@ -416,8 +416,7 @@ def main(args):
os.system(f"{sys.executable} -m pip freeze > {freeze_path}") os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}") wdb_logger.experiment.save(f"{freeze_path}")
# Raw dump of all args from pl.Trainer constructor trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps']
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
trainer_args.update({ trainer_args.update({
'default_root_dir': args.output_dir, 'default_root_dir': args.output_dir,
...@@ -658,6 +657,9 @@ if __name__ == "__main__": ...@@ -658,6 +657,9 @@ if __name__ == "__main__":
trainer_group.add_argument( trainer_group.add_argument(
"--num_sanity_val_steps", type=int, default=0, "--num_sanity_val_steps", type=int, default=0,
) )
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
args = parser.parse_args() args = parser.parse_args()
...@@ -673,7 +675,4 @@ if __name__ == "__main__": ...@@ -673,7 +675,4 @@ if __name__ == "__main__":
raise ValueError( raise ValueError(
"Choose between loading pretrained Jax-weights and a checkpoint-path") "Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment