Unverified commit 26ef0f19, authored by Douglas Trajano and committed by GitHub
Browse files

fix: Race Condition when using Sagemaker Checkpointing and Model Repository (#21614)

* Add _add_sm_patterns_to_gitignore

* Add _is_world_process_zero() call and move patterns arg to constant

* Update git status time.sleep

* Apply make style
parent 7bce8042
...@@ -3395,6 +3395,10 @@ class Trainer: ...@@ -3395,6 +3395,10 @@ class Trainer:
with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"]) writer.writelines(["checkpoint-*/"])
# Add "*.sagemaker" to .gitignore if using SageMaker
if os.environ.get("SM_TRAINING_ENV"):
self._add_sm_patterns_to_gitignore()
self.push_in_progress = None self.push_in_progress = None
def create_model_card( def create_model_card(
...@@ -3716,3 +3720,42 @@ class Trainer: ...@@ -3716,3 +3720,42 @@ class Trainer:
tensors = distributed_concat(tensors) tensors = distributed_concat(tensors)
return nested_numpify(tensors) return nested_numpify(tensors)
def _add_sm_patterns_to_gitignore(self) -> None:
    """Add SageMaker Checkpointing patterns to .gitignore file.

    Reads the ``.gitignore`` located in ``self.repo.local_dir`` and appends
    ``*.sagemaker-uploading`` and ``*.sagemaker-uploaded`` when they are not
    already listed as rules of their own. The file is then staged, and a
    commit + push is made when ``self.repo.is_repo_clean()`` reports false.

    Nothing is done when ``self.is_world_process_zero()`` is false.
    """
    # Make sure we only do this on the main process
    if not self.is_world_process_zero():
        return

    patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
    gitignore_path = os.path.join(self.repo.local_dir, ".gitignore")

    # Get current .gitignore content (a missing file is treated as empty)
    try:
        with open(gitignore_path, "r", encoding="utf-8") as f:
            current_content = f.read()
    except FileNotFoundError:
        current_content = ""

    # Compare against whole lines rather than substrings: a pattern that is
    # only mentioned inside a comment or a longer rule is not an active rule.
    existing_rules = {line.strip() for line in current_content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in existing_rules]

    # Write the .gitignore file only if something has to be added
    if missing:
        content = current_content
        # Only insert a separator when there is previous content that does
        # not already end its last line; an empty file must not start with
        # a blank line.
        if content and not content.endswith("\n"):
            content += "\n"
        content += "\n".join(missing) + "\n"

        with open(gitignore_path, "w", encoding="utf-8") as f:
            logger.debug(f"Writing .gitignore file. Content: {content}")
            f.write(content)

    self.repo.git_add(".gitignore")

    # avoid race condition with git status
    time.sleep(0.5)

    if not self.repo.is_repo_clean():
        self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
        self.repo.git_push()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment