Unverified Commit d3d22ce5 authored by Brian Whicheloe's avatar Brian Whicheloe Committed by GitHub
Browse files

Small modification to enable usage by external scripts (#956)



* Make training code usable by external scripts

Add parameter inputs to training and argument parsing function to allow this script to be used by an external call.

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 8332c1a6
...@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTokenizer ...@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__) logger = get_logger(__name__)
def parse_args(): def parse_args(input_args):
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument( parser.add_argument(
"--pretrained_model_name_or_path", "--pretrained_model_name_or_path",
...@@ -196,7 +196,11 @@ def parse_args(): ...@@ -196,7 +196,11 @@ def parse_args():
) )
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args() if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank: if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank args.local_rank = env_local_rank
...@@ -319,8 +323,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: ...@@ -319,8 +323,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}" return f"{organization}/{model_id}"
def main(): def main(args):
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir) logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator( accelerator = Accelerator(
...@@ -653,4 +656,5 @@ def main(): ...@@ -653,4 +656,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() args = parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment