"git@developer.sourcefind.cn:OpenDAS/Uni-Core.git" did not exist on "4e4dc4b80d76037928491ac389114b3ab3c16284"
Commit 3a65a7c9 authored by Marta's avatar Marta
Browse files

fix bool args parsing bug

parent cd5a1e13
...@@ -185,6 +185,16 @@ def main(args): ...@@ -185,6 +185,16 @@ def main(args):
) )
def bool_type(bool_str: str):
bool_str_lower = bool_str.lower()
if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
return False
elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
return True
else:
raise ValueError(f'Cannot interpret {bool_str} as bool')
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
...@@ -245,7 +255,7 @@ if __name__ == "__main__": ...@@ -245,7 +255,7 @@ if __name__ == "__main__":
files.""" files."""
) )
parser.add_argument( parser.add_argument(
"--use_small_bfd", type=bool, default=False, "--use_small_bfd", type=bool_type, default=False,
help="Whether to use a reduced version of the BFD database" help="Whether to use a reduced version of the BFD database"
) )
parser.add_argument( parser.add_argument(
...@@ -257,12 +267,12 @@ if __name__ == "__main__": ...@@ -257,12 +267,12 @@ if __name__ == "__main__":
help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled" help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
) )
parser.add_argument( parser.add_argument(
"--checkpoint_best_val", type=bool, default=True, "--checkpoint_best_val", type=bool_type, default=True,
help="""Whether to save the model parameters that perform best during help="""Whether to save the model parameters that perform best during
validation""" validation"""
) )
parser.add_argument( parser.add_argument(
"--early_stopping", type=bool, default=False, "--early_stopping", type=bool_type, default=False,
help="Whether to stop training when validation loss fails to decrease" help="Whether to stop training when validation loss fails to decrease"
) )
parser.add_argument( parser.add_argument(
...@@ -279,7 +289,7 @@ if __name__ == "__main__": ...@@ -279,7 +289,7 @@ if __name__ == "__main__":
help="Path to a model checkpoint from which to restore training state" help="Path to a model checkpoint from which to restore training state"
) )
parser.add_argument( parser.add_argument(
"--resume_model_weights_only", type=bool, default=False, "--resume_model_weights_only", type=bool_type, default=False,
help="Whether to load just model weights as opposed to training state" help="Whether to load just model weights as opposed to training state"
) )
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment