Commit 142c5e65 authored by Jennifer

Updates organization of command line flags for pl.Trainer

parent 2eda3215
@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
-        # return torch.optim.Adam(
-        #     self.model.parameters(),
-        #     lr=learning_rate,
-        #     eps=eps
-        # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
@@ -289,10 +284,13 @@ def main(args):
     if(args.seed is not None):
         seed_everything(args.seed, workers=True)
 
+    is_low_precision = args.precision in [
+        "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
+
     config = model_config(
         args.config_preset,
         train=True,
-        low_prec=(str(args.precision) == "16")
+        low_prec=is_low_precision,
     )
     if args.experiment_config_json:
         with open(args.experiment_config_json, 'r') as f:
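
Note: the membership test above covers the half- and bfloat16-precision spellings that newer PyTorch Lightning versions accept, where the old check only caught the literal string "16". A quick standalone illustration (not part of the commit):

    low_prec_values = ["16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
    for prec in ("32", "16", "bf16-mixed"):
        # Prints: 32 False / 16 True / bf16-mixed True
        print(prec, prec in low_prec_values)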
@@ -432,17 +430,17 @@ def main(args):
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer(
-        num_nodes=args.num_nodes,
-        devices=args.gpus,
-        precision=args.precision,
-        max_epochs=args.max_epochs,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-        profiler='simple',
-    )
+    trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
+                   'flush_logs_every_n_steps', 'num_sanity_val_steps',
+                   'reload_dataloaders_every_n_epochs']
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)
 
     if (args.resume_model_weights_only):
         ckpt_path = None
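
Note: the replacement builds the Trainer keyword arguments by whitelisting matching argparse flags out of the parsed namespace, so flag names must match pl.Trainer parameter names exactly. A minimal runnable sketch of the pattern, assuming only that pytorch_lightning is installed (the flag set here is illustrative, not the script's full list):

    import argparse

    import pytorch_lightning as pl

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--precision", type=str, default="bf16")
    parser.add_argument("--max_epochs", type=int, default=1)
    args = parser.parse_args([])  # defaults only, for the sketch

    # Flags whose names match pl.Trainer keyword arguments are picked out
    # of the namespace; everything else is added explicitly.
    trainer_kws = ["num_nodes", "precision", "max_epochs"]
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args["default_root_dir"] = args.output_dir
    trainer = pl.Trainer(**trainer_args)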
@@ -652,32 +650,39 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser.add_argument("--num_nodes", type=int, default=1)
-    parser.add_argument("--gpus", type=int, default=None)
-    parser.add_argument("--max_epochs", type=int, default=None)
-    parser.add_argument("--precision", type=str, default="32")
-    parser.add_argument("--log_every_n_steps", type=int, default=50)
-    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
-    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
-    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
-    parser.add_argument("--mpi_plugin", action="store_true", default=False)
-    # parser = pl.Trainer.add_argparse_args(parser)
-
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
-
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+        help='For determining optimal strategy and effective batch size.'
+    )
+    parser.add_argument("--mpi_plugin", action="store_true", default=False,
+                        help="Whether to use MPI for parallel processing")
+
+    trainer_group = parser.add_argument_group(
+        'Arguments to pass to PyTorch Lightning Trainer')
+    trainer_group.add_argument(
+        "--num_nodes", type=int, default=1,
+    )
+    trainer_group.add_argument(
+        "--precision", type=str, default='bf16',
+        help='Sets precision, lower precision improves runtime performance.',
+    )
+    trainer_group.add_argument(
+        "--max_epochs", type=int, default=1,
+    )
+    trainer_group.add_argument(
+        "--log_every_n_steps", type=int, default=25,
+    )
+    trainer_group.add_argument(
+        "--flush_logs_every_n_steps", type=int, default=5,
+    )
+    trainer_group.add_argument(
+        "--num_sanity_val_steps", type=int, default=0,
+    )
+    trainer_group.add_argument(
+        "--reload_dataloaders_every_n_epochs", type=int, default=1,
+    )
+    trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
+                               help="Accumulate gradients over k batches before next optimizer step.")
 
     args = parser.parse_args()
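
Note: argument groups change only how --help output is rendered; parsing behavior is unchanged. A short standalone check (illustrative, not from the commit):

    import argparse

    parser = argparse.ArgumentParser(prog="train")
    trainer_group = parser.add_argument_group(
        "Arguments to pass to PyTorch Lightning Trainer")
    trainer_group.add_argument("--num_nodes", type=int, default=1)
    trainer_group.add_argument("--precision", type=str, default="bf16")
    # Trainer flags now appear under their own heading in the help text.
    parser.print_help()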
@@ -692,7 +697,5 @@ if __name__ == "__main__":
     if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
 
-    # This re-applies the training-time filters at the beginning of every epoch
-    args.reload_dataloaders_every_n_epochs = 1
 
     main(args)