Commit 142c5e65 authored by Jennifer's avatar Jennifer
Browse files

Updates organization of command line flags for pl.Trainer

parent 2eda3215
......@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
learning_rate: float = 1e-3,
eps: float = 1e-5,
) -> torch.optim.Adam:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
optimizer = torch.optim.Adam(
self.model.parameters(),
......@@ -289,10 +284,13 @@ def main(args):
if(args.seed is not None):
    seed_everything(args.seed, workers=True)

# Precision strings that should turn on the model's low-precision settings.
# ("bf16-mixed" was previously listed twice; deduplicated.)
is_low_precision = args.precision in [
    "bf16-mixed", "16", "bf16", "16-true", "16-mixed"]

config = model_config(
    args.config_preset,
    train=True,
    # Removed the stale duplicate `low_prec=(str(args.precision) == "16")`
    # left over from the previous revision; only the list-based check is used.
    low_prec=is_low_precision,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
......@@ -432,17 +430,17 @@ def main(args):
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
wdb_logger.experiment.save(f"{freeze_path}")
trainer = pl.Trainer(
num_nodes=args.num_nodes,
devices=args.gpus,
precision=args.precision,
max_epochs=args.max_epochs,
default_root_dir=args.output_dir,
strategy=strategy,
callbacks=callbacks,
logger=loggers,
profiler='simple',
)
# CLI flags forwarded verbatim to pl.Trainer. Fixed typo
# 'flush_logs_ever_n_steps' -> 'flush_logs_every_n_steps' so the
# --flush_logs_every_n_steps flag's value actually reaches the Trainer
# (it is declared with the correct spelling below).
trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
               'flush_logs_every_n_steps', 'num_sanity_val_steps',
               'reload_dataloaders_every_n_epochs']
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
# NOTE(review): --accumulate_grad_batches is parsed but not listed in
# trainer_kws, so it is never forwarded to the Trainer — confirm whether
# it is consumed elsewhere (e.g. effective-batch-size math) or missing here.
trainer_args.update({
    'default_root_dir': args.output_dir,
    'strategy': strategy,
    'callbacks': callbacks,
    'logger': loggers,
})
trainer = pl.Trainer(**trainer_args)
if (args.resume_model_weights_only):
ckpt_path = None
......@@ -652,32 +650,39 @@ if __name__ == "__main__":
parser.add_argument(
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
)
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--gpus", type=int, default=None)
parser.add_argument("--max_epochs", type=int, default=None)
parser.add_argument("--precision", type=str, default="32")
parser.add_argument("--log_every_n_steps", type=int, default=50)
parser.add_argument("--accumulate_grad_batches", type=int, default=1)
parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
parser.add_argument("--num_sanity_val_steps", type=int, default=0)
parser.add_argument("--mpi_plugin", action="store_true", default=False)
# parser = pl.Trainer.add_argparse_args(parser)
# Disable the initial validation pass
parser.set_defaults(
num_sanity_val_steps=0,
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments(
parser,
[
"--accelerator",
"--resume_from_checkpoint",
"--reload_dataloaders_every_epoch",
"--reload_dataloaders_every_n_epochs",
]
parser.add_argument(
"--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.'
)
# Typo fix in help text: "parallele" -> "parallel".
parser.add_argument("--mpi_plugin", action="store_true", default=False,
                    help="Whether to use MPI for parallel processing")
trainer_group = parser.add_argument_group(
'Arguments to pass to PyTorch Lightning Trainer')
trainer_group.add_argument(
"--num_nodes", type=int, default=1,
)
trainer_group.add_argument(
"--precision", type=str, default='bf16',
help='Sets precision, lower precision improves runtime performance.',
)
trainer_group.add_argument(
"--max_epochs", type=int, default=1,
)
trainer_group.add_argument(
"--log_every_n_steps", type=int, default=25,
)
trainer_group.add_argument(
"--flush_logs_every_n_steps", type=int, default=5,
)
trainer_group.add_argument(
"--num_sanity_val_steps", type=int, default=0,
)
trainer_group.add_argument(
"--reload_dataloaders_every_n_epochs", type=int, default=1,
)
trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
help="Accumulate gradients over k batches before next optimizer step.")
args = parser.parse_args()
......@@ -692,7 +697,5 @@ if __name__ == "__main__":
if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")
# This re-applies the training-time filters at the beginning of every epoch
args.reload_dataloaders_every_n_epochs = 1
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment