Unverified Commit 5ea4be86 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Create repo before cloning in examples (#2047)

* Create repo before cloning in examples

* code quality
parent e5ff7554
...@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon ...@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -551,7 +551,8 @@ def main(args): ...@@ -551,7 +551,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version ...@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
from flax import jax_utils from flax import jax_utils
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import shard from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from jax.experimental.compilation_cache import compilation_cache as cc from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
...@@ -387,7 +387,8 @@ def main(): ...@@ -387,7 +387,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -599,8 +599,8 @@ def main(args): ...@@ -599,8 +599,8 @@ def main(args):
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo_name = create_repo(repo_name, exist_ok=True) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name) repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -20,7 +20,7 @@ from colossalai.utils import get_current_device ...@@ -20,7 +20,7 @@ from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -420,7 +420,8 @@ def main(args): ...@@ -420,7 +420,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -25,7 +25,7 @@ from diffusers import ( ...@@ -25,7 +25,7 @@ from diffusers import (
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -471,7 +471,8 @@ def main(): ...@@ -471,7 +471,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -21,7 +21,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi ...@@ -21,7 +21,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released # TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version from packaging import version
...@@ -393,7 +393,8 @@ def main(): ...@@ -393,7 +393,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon ...@@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -570,7 +570,8 @@ def main(args): ...@@ -570,7 +570,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -38,7 +38,7 @@ from diffusers.optimization import get_scheduler ...@@ -38,7 +38,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
...@@ -343,7 +343,8 @@ def main(): ...@@ -343,7 +343,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version ...@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
from flax import jax_utils from flax import jax_utils
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import shard from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision import transforms from torchvision import transforms
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
...@@ -255,7 +255,8 @@ def main(): ...@@ -255,7 +255,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe ...@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released # TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version from packaging import version
...@@ -464,7 +464,8 @@ def main(): ...@@ -464,7 +464,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -28,7 +28,7 @@ from diffusers.utils import check_min_version ...@@ -28,7 +28,7 @@ from diffusers.utils import check_min_version
from flax import jax_utils from flax import jax_utils
from flax.training import train_state from flax.training import train_state
from flax.training.common_utils import shard from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released # TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version from packaging import version
...@@ -372,7 +372,8 @@ def main(): ...@@ -372,7 +372,8 @@ def main():
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -19,7 +19,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel ...@@ -19,7 +19,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
...@@ -287,7 +287,8 @@ def main(args): ...@@ -287,7 +287,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
...@@ -15,7 +15,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel ...@@ -15,7 +15,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
...@@ -371,7 +371,8 @@ def main(args): ...@@ -371,7 +371,8 @@ def main(args):
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else: else:
repo_name = args.hub_model_id repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name) create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore: if "step_*" not in gitignore:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment