Unverified Commit d4b3e359 authored by Zachary Mueller's avatar Zachary Mueller Committed by GitHub
Browse files

Don't push checkpoints to hub in `no_trainer` scripts (#16703)

Adds checkpoint prefixes to the gitignore if `push_to_hub` is used along with `checkpointint_steps`
parent c04619ec
...@@ -39,6 +39,7 @@ from tqdm.auto import tqdm ...@@ -39,6 +39,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator, DistributedType from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -50,7 +51,6 @@ from transformers import ( ...@@ -50,7 +51,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -258,6 +258,12 @@ def main(): ...@@ -258,6 +258,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
...@@ -542,7 +548,6 @@ def main(): ...@@ -542,7 +548,6 @@ def main():
if args.output_dir is not None: if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir) output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir) accelerator.save_state(output_dir)
if completed_steps >= args.max_train_steps: if completed_steps >= args.max_train_steps:
break break
......
...@@ -39,6 +39,7 @@ from tqdm.auto import tqdm ...@@ -39,6 +39,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator, DistributedType from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -50,7 +51,6 @@ from transformers import ( ...@@ -50,7 +51,6 @@ from transformers import (
DataCollatorForLanguageModeling, DataCollatorForLanguageModeling,
SchedulerType, SchedulerType,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -269,6 +269,12 @@ def main(): ...@@ -269,6 +269,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -37,6 +37,7 @@ from tqdm.auto import tqdm ...@@ -37,6 +37,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -49,7 +50,6 @@ from transformers import ( ...@@ -49,7 +50,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import PaddingStrategy, get_full_repo_name from transformers.utils import PaddingStrategy, get_full_repo_name
...@@ -296,6 +296,12 @@ def main(): ...@@ -296,6 +296,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -34,6 +34,7 @@ from tqdm.auto import tqdm ...@@ -34,6 +34,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
AdamW, AdamW,
...@@ -45,7 +46,6 @@ from transformers import ( ...@@ -45,7 +46,6 @@ from transformers import (
XLNetTokenizerFast, XLNetTokenizerFast,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import check_min_version, get_full_repo_name from transformers.utils import check_min_version, get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -290,6 +290,12 @@ def main(): ...@@ -290,6 +290,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -35,6 +35,7 @@ from tqdm.auto import tqdm ...@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -48,7 +49,6 @@ from transformers import ( ...@@ -48,7 +49,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import check_min_version, get_full_repo_name from transformers.utils import check_min_version, get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -320,6 +320,12 @@ def main(): ...@@ -320,6 +320,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -36,6 +36,7 @@ from tqdm.auto import tqdm ...@@ -36,6 +36,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from filelock import FileLock from filelock import FileLock
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
...@@ -48,7 +49,6 @@ from transformers import ( ...@@ -48,7 +49,6 @@ from transformers import (
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
SchedulerType, SchedulerType,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name, is_offline_mode from transformers.utils import get_full_repo_name, is_offline_mode
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -346,6 +346,12 @@ def main(): ...@@ -346,6 +346,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -28,6 +28,7 @@ from tqdm.auto import tqdm ...@@ -28,6 +28,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
AdamW, AdamW,
...@@ -39,7 +40,6 @@ from transformers import ( ...@@ -39,7 +40,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -223,6 +223,12 @@ def main(): ...@@ -223,6 +223,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -34,6 +34,7 @@ from tqdm.auto import tqdm ...@@ -34,6 +34,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -47,7 +48,6 @@ from transformers import ( ...@@ -47,7 +48,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -277,6 +277,12 @@ def main(): ...@@ -277,6 +277,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
...@@ -35,6 +35,7 @@ from tqdm.auto import tqdm ...@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed
from huggingface_hub import Repository from huggingface_hub import Repository
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
...@@ -49,7 +50,6 @@ from transformers import ( ...@@ -49,7 +50,6 @@ from transformers import (
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
set_seed,
) )
from transformers.utils import get_full_repo_name from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
...@@ -319,6 +319,12 @@ def main(): ...@@ -319,6 +319,12 @@ def main():
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None: elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment