"docs/zh_cn/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "11d8554cf93b62f72d94996fc7fed62748db722f"
Unverified commit a87e88b7, authored by Lucain, committed by GitHub

Use `upload_folder` in training scripts (#2934)



use upload folder in training scripts
Co-authored-by: testbot <lucainp@hf.co>
parent a0263b2e
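The change is mechanical across all scripts: drop the local `get_full_repo_name` helper and the git-based `Repository` clone/push, create the repo once with `create_repo` (whose return value carries the fully qualified repo id), and upload the output folder over HTTP at the end of training. A minimal sketch of the new pattern follows; `output_dir` is an illustrative placeholder, not a value taken from any one script:

    from pathlib import Path

    from huggingface_hub import create_repo, upload_folder

    output_dir = "./output"  # illustrative placeholder, not a value from the diff

    # Create (or reuse) the Hub repo up front; the returned RepoUrl exposes the
    # fully qualified "namespace/name" as .repo_id.
    repo_id = create_repo(repo_id=Path(output_dir).name, exist_ok=True).repo_id

    # One HTTP commit at the end of training; ignore_patterns replaces the old
    # .gitignore entries that kept intermediate checkpoints out of the repo.
    upload_folder(
        repo_id=repo_id,
        folder_path=output_dir,
        commit_message="End of training",
        ignore_patterns=["step_*", "epoch_*"],
    )

Unlike `Repository.push_to_hub`, no local clone is kept in the output directory, so the scripts no longer need to write a `.gitignore` at all.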
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import accelerate
 import numpy as np
@@ -31,7 +30,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
 from torchvision import transforms
@@ -661,16 +660,6 @@ def collate_fn(examples):
     }
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -704,22 +693,14 @@ def main(args):
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1053,7 +1034,12 @@ def main(args):
         controlnet.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import jax
 import jax.numpy as jnp
@@ -33,7 +32,7 @@ from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from PIL import Image
 from torch.utils.data import IterableDataset
 from torchvision import transforms
@@ -148,7 +147,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
     return image_logs
 
 
-def save_model_card(repo_name, image_logs=None, base_model=str, repo_folder=None):
+def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
     img_str = ""
     for i, log in enumerate(image_logs):
         images = log["images"]
@@ -174,7 +173,7 @@ inference: true
 ---
     """
     model_card = f"""
-# controlnet- {repo_name}
+# controlnet- {repo_id}
 These are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. \n
 {img_str}
@@ -612,16 +611,6 @@ def collate_fn(examples):
     return batch
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def get_params_to_save(params):
     return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
 
@@ -656,22 +645,14 @@ def main():
     # Handle the repository creation
     if jax.process_index() == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo_url = create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_url, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -1020,12 +1001,17 @@ def main():
     if args.push_to_hub:
         save_model_card(
-            repo_name,
+            repo_id,
             image_logs=image_logs,
             base_model=args.pretrained_model_name_or_path,
             repo_folder=args.output_dir,
         )
-        repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+        upload_folder(
+            repo_id=repo_id,
+            folder_path=args.output_dir,
+            commit_message="End of training",
+            ignore_patterns=["step_*", "epoch_*"],
+        )
 
 
 if __name__ == "__main__":
......
@@ -21,7 +21,6 @@ import math
 import os
 import warnings
 from pathlib import Path
-from typing import Optional
 
 import accelerate
 import numpy as np
@@ -32,7 +31,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -575,16 +574,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -677,22 +666,14 @@ def main(args):
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1043,7 +1024,12 @@ def main(args):
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -20,7 +20,6 @@ import math
 import os
 import warnings
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import torch
@@ -30,7 +29,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
 from torch.utils.data import Dataset
@@ -59,7 +58,7 @@ check_min_version("0.15.0.dev0")
 logger = get_logger(__name__)
 
 
-def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
+def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -80,7 +79,7 @@ inference: true
 ---
     """
     model_card = f"""
-# LoRA DreamBooth - {repo_name}
+# LoRA DreamBooth - {repo_id}
 These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
 {img_str}
@@ -528,16 +527,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -625,23 +614,14 @@ def main(args):
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -1027,13 +1007,18 @@ def main(args):
         if args.push_to_hub:
             save_model_card(
-                repo_name,
+                repo_id,
                 images=images,
                 base_model=args.pretrained_model_name_or_path,
                 prompt=args.instance_prompt,
                 repo_folder=args.output_dir,
             )
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -21,7 +21,6 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional
 
 import accelerate
 import datasets
@@ -37,7 +36,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -363,16 +362,6 @@ def parse_args():
     return args
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def convert_to_np(image, resolution):
     image = image.convert("RGB").resize((resolution, resolution))
     return np.array(image).transpose(2, 0, 1)
@@ -436,22 +425,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load scheduler, tokenizer and models.
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -968,7 +949,12 @@ def main():
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
         if args.validation_prompt is not None:
             edited_images = []
......
@@ -3,7 +3,6 @@ import hashlib
 import math
 import os
 from pathlib import Path
-from typing import Optional
 
 import colossalai
 import torch
@@ -16,7 +15,7 @@ from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
 from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -344,16 +343,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 # Gemini + ZeRO DDP
 def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
@@ -413,22 +402,14 @@ def main(args):
     # Handle the repository creation
     if local_rank == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
@@ -679,7 +660,12 @@ def main(args):
         logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
 
 if __name__ == "__main__":
......
@@ -5,7 +5,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import torch
@@ -14,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from PIL import Image, ImageDraw
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -402,16 +401,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -485,22 +474,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -816,7 +797,12 @@ def main():
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -4,7 +4,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import torch
@@ -13,7 +12,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from PIL import Image, ImageDraw
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -401,16 +400,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main():
     args = parse_args()
     logging_dir = Path(args.output_dir, args.logging_dir)
@@ -484,22 +473,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -835,7 +816,12 @@ def main():
         unet.save_attn_procs(args.output_dir)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -4,7 +4,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import intel_extension_for_pytorch as ipex
 import numpy as np
@@ -15,7 +14,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
@@ -356,16 +355,6 @@ class TextualInversionDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def freeze_params(params):
     for param in params:
         param.requires_grad = False
@@ -388,22 +377,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -640,7 +621,12 @@ def main():
         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -22,7 +22,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import datasets
 import numpy as np
@@ -34,7 +33,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -55,7 +54,7 @@ check_min_version("0.14.0.dev0")
 logger = get_logger(__name__, log_level="INFO")
 
 
-def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None):
+def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -75,7 +74,7 @@ inference: true
 ---
     """
     model_card = f"""
-# LoRA text2image fine-tuning - {repo_name}
+# LoRA text2image fine-tuning - {repo_id}
 These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
 {img_str}
 """
@@ -386,16 +385,6 @@ def parse_args():
     return args
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 DATASET_NAME_MAPPING = {
     "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }
@@ -441,22 +430,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo_name = create_repo(repo_name, exist_ok=True)
-            repo = Repository(args.output_dir, clone_from=repo_name)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load scheduler, tokenizer and models.
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -945,13 +926,18 @@ def main():
         if args.push_to_hub:
             save_model_card(
-                repo_name,
+                repo_id,
                 images=images,
                 base_model=args.pretrained_model_name_or_path,
                 dataset_name=args.dataset_name,
                 repo_folder=args.output_dir,
             )
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     # Final inference
     # Load previous pipeline
......
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import PIL
@@ -30,7 +29,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from multi_token_clip import MultiTokenCLIPTokenizer
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -547,16 +546,6 @@ class TextualInversionDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -596,22 +585,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load tokenizer
     if args.tokenizer_name:
         tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -932,7 +913,12 @@ def main():
         save_progress(tokenizer, text_encoder, accelerator, save_path)
 
        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -4,7 +4,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import jax
 import jax.numpy as jnp
@@ -17,7 +16,7 @@ import transformers
 from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
@@ -326,16 +325,6 @@ class TextualInversionDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
     if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
         return
@@ -367,22 +356,14 @@ def main():
     set_seed(args.seed)
 
     if jax.process_index() == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -661,7 +642,12 @@ def main():
         jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
 
         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
 
 if __name__ == "__main__":
......
@@ -6,7 +6,6 @@ import math
 import os
 import warnings
 from pathlib import Path
-from typing import Optional
 
 import datasets
 import torch
@@ -16,7 +15,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
@@ -463,16 +462,6 @@ class PromptDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
@@ -584,22 +573,14 @@ def main(args):
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load the tokenizer
     if args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
@@ -886,7 +867,12 @@ def main(args):
         pipeline.save_pretrained(args.output_dir)
 
        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import datasets
 import numpy as np
@@ -31,7 +30,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from onnxruntime.training.ortmodule import ORTModule
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -313,16 +312,6 @@ def parse_args():
     return args
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 dataset_name_mapping = {
     "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }
@@ -364,22 +353,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load scheduler, tokenizer and models.
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -732,7 +713,12 @@ def main():
         pipeline.save_pretrained(args.output_dir)
 
        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import datasets
 import numpy as np
@@ -31,7 +30,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from onnxruntime.training.ortmodule import ORTModule
 
 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -463,16 +462,6 @@ class TextualInversionDataset(Dataset):
         return example
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -514,22 +503,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -851,7 +832,12 @@ def main():
         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
......
@@ -19,7 +19,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional
 
 import accelerate
 import datasets
@@ -32,7 +31,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -315,16 +314,6 @@ def parse_args():
     return args
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 dataset_name_mapping = {
     "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }
@@ -376,22 +365,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
 
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id
+
     # Load scheduler, tokenizer and models.
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -786,7 +767,12 @@ def main():
         pipeline.save_pretrained(args.output_dir)
 
        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )
 
     accelerator.end_training()
...
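The deleted `get_full_repo_name` helper resolved the `username/` prefix client-side via `HfFolder.get_token()` and `whoami`. That step is no longer needed: given a bare repo name, `create_repo` resolves the namespace of the authenticated token server-side. A hedged illustration, with placeholder names:

from huggingface_hub import create_repo

# A bare name is created under the token's own namespace, e.g.
# "pokemon-text2image" may resolve to "my-user/pokemon-text2image".
repo_id = create_repo(repo_id="pokemon-text2image", exist_ok=True).repo_id
print(repo_id)

An explicit `org/name` id targets that organization instead, which is what passing `--hub_model_id` achieves in the scripts in this diff.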
@@ -4,7 +4,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional

 import jax
 import jax.numpy as jnp
@@ -17,7 +16,7 @@ from datasets import load_dataset
 from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
@@ -222,16 +221,6 @@ def parse_args():
     return args


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 dataset_name_mapping = {
     "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }
@@ -261,22 +250,14 @@ def main():
     # Handle the repository creation
     if jax.process_index() == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -581,7 +562,12 @@ def main():
         )

         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )


 if __name__ == "__main__":
...
@@ -20,7 +20,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional

 import datasets
 import numpy as np
@@ -32,7 +31,7 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -53,7 +52,7 @@ check_min_version("0.15.0.dev0")
 logger = get_logger(__name__, log_level="INFO")


-def save_model_card(repo_name, images=None, base_model=str, dataset_name=str, repo_folder=None):
+def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -73,7 +72,7 @@ inference: true
 ---
 """
     model_card = f"""
-# LoRA text2image fine-tuning - {repo_name}
+# LoRA text2image fine-tuning - {repo_id}
 These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
 {img_str}
 """
@@ -347,16 +346,6 @@ def parse_args():
     return args


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 DATASET_NAME_MAPPING = {
     "lambdalabs/pokemon-blip-captions": ("image", "text"),
 }
@@ -402,22 +391,13 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo_name = create_repo(repo_name, exist_ok=True)
-            repo = Repository(args.output_dir, clone_from=repo_name)
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

     # Load scheduler, tokenizer and models.
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
@@ -830,13 +810,18 @@ def main():
         if args.push_to_hub:
             save_model_card(
-                repo_name,
+                repo_id,
                 images=images,
                 base_model=args.pretrained_model_name_or_path,
                 dataset_name=args.dataset_name,
                 repo_folder=args.output_dir,
             )
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )

     # Final inference
     # Load previous pipeline
...
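In the LoRA script the fully qualified `repo_id` is also threaded through `save_model_card` (the script-local helper whose signature changes above), so the generated card's title matches the repo the weights land in. A sketch of that end-of-training sequence; every value here is a placeholder, not taken from the diff:

from huggingface_hub import create_repo, upload_folder

output_dir = "sd-model-lora"  # placeholder output directory
images = []  # validation images collected during training, if any

repo_id = create_repo(repo_id="my-lora-weights", exist_ok=True).repo_id

# The card and its sample images are written into output_dir first...
save_model_card(
    repo_id,  # formerly repo_name; now the "namespace/name" id
    images=images,
    base_model="runwayml/stable-diffusion-v1-5",  # placeholder base model
    dataset_name="lambdalabs/pokemon-blip-captions",
    repo_folder=output_dir,
)

# ...and then shipped together with the weights in one commit.
upload_folder(
    repo_id=repo_id,
    folder_path=output_dir,
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)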
@@ -20,7 +20,6 @@ import os
 import random
 import warnings
 from pathlib import Path
-from typing import Optional

 import numpy as np
 import PIL
@@ -31,7 +30,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder

 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
@@ -519,16 +518,6 @@ class TextualInversionDataset(Dataset):
         return example


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -567,22 +556,14 @@ def main():
     # Handle the repository creation
     if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

     # Load tokenizer
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -880,7 +861,12 @@ def main():
         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )

     accelerator.end_training()
...
@@ -4,7 +4,6 @@ import math
 import os
 import random
 from pathlib import Path
-from typing import Optional

 import jax
 import jax.numpy as jnp
@@ -17,7 +16,7 @@ import transformers
 from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder

 # TODO: remove and import from diffusers.utils when the new version of diffusers is released
 from packaging import version
@@ -339,16 +338,6 @@ class TextualInversionDataset(Dataset):
         return example


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
     if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
         return
@@ -380,22 +369,14 @@ def main():
         set_seed(args.seed)

     if jax.process_index() == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
             os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -688,7 +669,12 @@ def main():
         jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)

         if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )


 if __name__ == "__main__":
...