Unverified Commit 6e221334 authored by apolinário, committed by GitHub
Browse files

[advanced_dreambooth_lora_sdxl_training_script] save embeddings locally fix (#6058)



* Update train_dreambooth_lora_sdxl_advanced.py

* remove global function args from dreamboothdataset class

* style

* style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 53bc30dd
...@@ -123,16 +123,26 @@ def save_model_card( ...@@ -123,16 +123,26 @@ def save_model_card(
""" """
trigger_str = f"You should use {instance_prompt} to trigger the image generation." trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
diffusers_example_pivotal = ""
if train_text_encoder_ti: if train_text_encoder_ti:
trigger_str = ( trigger_str = (
"To trigger image generation of trained concept(or concepts) replace each concept identifier " "To trigger image generation of trained concept(or concepts) replace each concept identifier "
"in you prompt with the new inserted tokens:\n" "in you prompt with the new inserted tokens:\n"
) )
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
"""
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id="{repo_id}", filename="embeddings.safetensors", repo_type="model")
state_dict = load_file(embedding_path)
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
"""
if token_abstraction_dict: if token_abstraction_dict:
for key, value in token_abstraction_dict.items(): for key, value in token_abstraction_dict.items():
tokens = "".join(value) tokens = "".join(value)
trigger_str += f""" trigger_str += f"""
to trigger concept `{key}->` use `{tokens}` in your prompt \n to trigger concept `{key}` → use `{tokens}` in your prompt \n
""" """
yaml = f""" yaml = f"""
...@@ -172,7 +182,21 @@ Special VAE used for training: {vae_path}. ...@@ -172,7 +182,21 @@ Special VAE used for training: {vae_path}.
{trigger_str} {trigger_str}
## Download model ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
```py
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## Download model (use it with UIs such as AUTO1111, Comfy, SD.Next, Invoke)
Weights for this model are available in Safetensors format. Weights for this model are available in Safetensors format.
...@@ -791,6 +815,12 @@ class DreamBoothDataset(Dataset): ...@@ -791,6 +815,12 @@ class DreamBoothDataset(Dataset):
instance_data_root, instance_data_root,
instance_prompt, instance_prompt,
class_prompt, class_prompt,
dataset_name,
dataset_config_name,
cache_dir,
image_column,
caption_column,
train_text_encoder_ti,
class_data_root=None, class_data_root=None,
class_num=None, class_num=None,
token_abstraction_dict=None, # token mapping for textual inversion token_abstraction_dict=None, # token mapping for textual inversion
...@@ -805,10 +835,10 @@ class DreamBoothDataset(Dataset): ...@@ -805,10 +835,10 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None self.custom_instance_prompts = None
self.class_prompt = class_prompt self.class_prompt = class_prompt
self.token_abstraction_dict = token_abstraction_dict self.token_abstraction_dict = token_abstraction_dict
self.train_text_encoder_ti = train_text_encoder_ti
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset # we load the training data using load_dataset
if args.dataset_name is not None: if dataset_name is not None:
try: try:
from datasets import load_dataset from datasets import load_dataset
except ImportError: except ImportError:
...@@ -821,26 +851,25 @@ class DreamBoothDataset(Dataset): ...@@ -821,26 +851,25 @@ class DreamBoothDataset(Dataset):
# See more about loading custom images at # See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset( dataset = load_dataset(
args.dataset_name, dataset_name,
args.dataset_config_name, dataset_config_name,
cache_dir=args.cache_dir, cache_dir=cache_dir,
) )
# Preprocessing the datasets. # Preprocessing the datasets.
column_names = dataset["train"].column_names column_names = dataset["train"].column_names
# 6. Get the column names for input/target. # 6. Get the column names for input/target.
if args.image_column is None: if image_column is None:
image_column = column_names[0] image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}") logger.info(f"image column defaulting to {image_column}")
else: else:
image_column = args.image_column
if image_column not in column_names: if image_column not in column_names:
raise ValueError( raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
) )
instance_images = dataset["train"][image_column] instance_images = dataset["train"][image_column]
if args.caption_column is None: if caption_column is None:
logger.info( logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset " "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the " "contains captions/prompts for the images, make sure to specify the "
...@@ -848,11 +877,11 @@ class DreamBoothDataset(Dataset): ...@@ -848,11 +877,11 @@ class DreamBoothDataset(Dataset):
) )
self.custom_instance_prompts = None self.custom_instance_prompts = None
else: else:
if args.caption_column not in column_names: if caption_column not in column_names:
raise ValueError( raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
) )
custom_instance_prompts = dataset["train"][args.caption_column] custom_instance_prompts = dataset["train"][caption_column]
# create final list of captions according to --repeats # create final list of captions according to --repeats
self.custom_instance_prompts = [] self.custom_instance_prompts = []
for caption in custom_instance_prompts: for caption in custom_instance_prompts:
...@@ -907,7 +936,7 @@ class DreamBoothDataset(Dataset): ...@@ -907,7 +936,7 @@ class DreamBoothDataset(Dataset):
if self.custom_instance_prompts: if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images] caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption: if caption:
if args.train_text_encoder_ti: if self.train_text_encoder_ti:
# replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc. # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in self.token_abstraction_dict.items(): for token_abs, token_replacement in self.token_abstraction_dict.items():
caption = caption.replace(token_abs, "".join(token_replacement)) caption = caption.replace(token_abs, "".join(token_replacement))
...@@ -1093,10 +1122,10 @@ def main(args): ...@@ -1093,10 +1122,10 @@ def main(args):
if args.output_dir is not None: if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
model_id = args.hub_model_id or Path(args.output_dir).name
repo_id = None
if args.push_to_hub: if args.push_to_hub:
repo_id = create_repo( repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# Load the tokenizers # Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained( tokenizer_one = AutoTokenizer.from_pretrained(
...@@ -1464,6 +1493,12 @@ def main(args): ...@@ -1464,6 +1493,12 @@ def main(args):
instance_data_root=args.instance_data_dir, instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt, instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt, class_prompt=args.class_prompt,
dataset_name=args.dataset_name,
dataset_config_name=args.dataset_config_name,
cache_dir=args.cache_dir,
image_column=args.image_column,
train_text_encoder_ti=args.train_text_encoder_ti,
caption_column=args.caption_column,
class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_data_root=args.class_data_dir if args.with_prior_preservation else None,
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
class_num=args.num_class_images, class_num=args.num_class_images,
...@@ -2004,13 +2039,12 @@ def main(args): ...@@ -2004,13 +2039,12 @@ def main(args):
} }
) )
if args.push_to_hub:
if args.train_text_encoder_ti: if args.train_text_encoder_ti:
embedding_handler.save_embeddings( embedding_handler.save_embeddings(
f"{args.output_dir}/embeddings.safetensors", f"{args.output_dir}/embeddings.safetensors",
) )
save_model_card( save_model_card(
repo_id, model_id if not args.push_to_hub else repo_id,
images=images, images=images,
base_model=args.pretrained_model_name_or_path, base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder, train_text_encoder=args.train_text_encoder,
...@@ -2021,6 +2055,7 @@ def main(args): ...@@ -2021,6 +2055,7 @@ def main(args):
repo_folder=args.output_dir, repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path, vae_path=args.pretrained_vae_model_name_or_path,
) )
if args.push_to_hub:
upload_folder( upload_folder(
repo_id=repo_id, repo_id=repo_id,
folder_path=args.output_dir, folder_path=args.output_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment