Commit 3a65a7c9 authored by Marta's avatar Marta
Browse files

fix bool args parsing bug

parent cd5a1e13
......@@ -185,6 +185,16 @@ def main(args):
)
def bool_type(bool_str: str):
bool_str_lower = bool_str.lower()
if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
return False
elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
return True
else:
raise ValueError(f'Cannot interpret {bool_str} as bool')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
......@@ -245,7 +255,7 @@ if __name__ == "__main__":
files."""
)
parser.add_argument(
"--use_small_bfd", type=bool, default=False,
"--use_small_bfd", type=bool_type, default=False,
help="Whether to use a reduced version of the BFD database"
)
parser.add_argument(
......@@ -257,12 +267,12 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser.add_argument(
"--checkpoint_best_val", type=bool, default=True,
"--checkpoint_best_val", type=bool_type, default=True,
help="""Whether to save the model parameters that perform best during
validation"""
)
parser.add_argument(
"--early_stopping", type=bool, default=False,
"--early_stopping", type=bool_type, default=False,
help="Whether to stop training when validation loss fails to decrease"
)
parser.add_argument(
......@@ -279,7 +289,7 @@ if __name__ == "__main__":
help="Path to a model checkpoint from which to restore training state"
)
parser.add_argument(
"--resume_model_weights_only", type=bool, default=False,
"--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state"
)
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment