Commit 1ae833bf authored by Jennifer's avatar Jennifer
Browse files

Updates low_precision check to use current precision settings.

parent 80e63410
...@@ -287,10 +287,13 @@ def main(args): ...@@ -287,10 +287,13 @@ def main(args):
if (args.seed is not None): if (args.seed is not None):
seed_everything(args.seed, workers=True) seed_everything(args.seed, workers=True)
is_low_precision = args.precision in [
"bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
config = model_config( config = model_config(
args.config_preset, args.config_preset,
train=True, train=True,
low_prec=(str(args.precision) == "16") low_prec=is_low_precision,
) )
if args.experiment_config_json: if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f: with open(args.experiment_config_json, 'r') as f:
...@@ -643,7 +646,8 @@ if __name__ == "__main__": ...@@ -643,7 +646,8 @@ if __name__ == "__main__":
"--num_nodes", type=int, default=1, "--num_nodes", type=int, default=1,
) )
trainer_group.add_argument( trainer_group.add_argument(
"--precision", type=str, default='bf16', help='Sets precision, lower precision improves runtime performance.' "--precision", type=str, default='bf16',
help='Sets precision, lower precision improves runtime performance.',
) )
trainer_group.add_argument( trainer_group.add_argument(
"--max_epochs", type=int, default=1, "--max_epochs", type=int, default=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment